Adding `immutable` to global load ops. (#16295)
This allows us to locally know whether a load is of an immutable global
or not. We only ever make globals immutable during FoldGlobals and
perform the fixup of all loads we can.
On programs with lots of globals this speeds things up quite a bit by
dropping propagateLiveness from e.g. 20s->12s (sdxl_turbo_unet).
diff --git a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
index 0da087d..95f9381 100644
--- a/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/TFLite/Transforms/WrapEntryPoints.cpp
@@ -110,10 +110,8 @@
for (unsigned i = 0; i < tensorType.getRank(); ++i) {
if (tensorType.isDynamicDim(i)) {
auto globalOp = globalOps[dynamicDimIdx++];
- dims.push_back(
- builder
- .create<IREE::Util::GlobalLoadOp>(globalOp.getLoc(), globalOp)
- .getResult());
+ dims.push_back(globalOp.createLoadOp(globalOp.getLoc(), builder)
+ .getLoadedGlobalValue());
}
}
return dims;
@@ -207,8 +205,8 @@
auto entryBuilder = OpBuilder::atBlockBegin(&entryBlock);
// Go back and insert a check for the dirty flag.
- auto dirtyValue = entryBuilder.createOrFold<IREE::Util::GlobalLoadOp>(
- loc, dirtyGlobalOp.getType(), dirtyGlobalOp.getName());
+ auto dirtyValue =
+ dirtyGlobalOp.createLoadOp(loc, entryBuilder).getLoadedGlobalValue();
auto *recalculateBlock = calcFuncOp.addBlock();
auto *returnBlock = calcFuncOp.addBlock();
entryBuilder.create<mlir::cf::CondBranchOp>(loc, dirtyValue,
@@ -257,16 +255,15 @@
for (int64_t i = 0; i < outputDynamicDims.globalOps.size(); ++i) {
auto dimValue =
exitBuilder.createOrFold<tensor::DimOp>(exitLoc, outputValue, i);
- exitBuilder.create<IREE::Util::GlobalStoreOp>(
- exitLoc, dimValue, outputDynamicDims.globalOps[i].getSymName());
+ outputDynamicDims.globalOps[i].createStoreOp(exitLoc, dimValue,
+ exitBuilder);
}
}
// Clear the dirty flag now that the shapes have been updated.
auto falseValue =
exitBuilder.createOrFold<arith::ConstantIntOp>(exitLoc, 0, 1);
- exitBuilder.create<IREE::Util::GlobalStoreOp>(exitLoc, falseValue,
- dirtyGlobalOp.getSymName());
+ dirtyGlobalOp.createStoreOp(exitLoc, falseValue, exitBuilder);
exitBuilder.create<mlir::func::ReturnOp>(exitLoc);
returnOp.erase();
}
@@ -327,8 +324,9 @@
for (unsigned i = 0; i < shapeType.getRank(); ++i) {
Value dimValue;
if (shapeType.isDynamicDim(i)) {
- dimValue = builder.create<IREE::Util::GlobalLoadOp>(
- loc, dynamicDims.globalOps[dynamicDimIdx++]);
+ dimValue = dynamicDims.globalOps[dynamicDimIdx++]
+ .createLoadOp(loc, builder)
+ .getLoadedGlobalValue();
} else {
dimValue = builder.createOrFold<arith::ConstantIndexOp>(
loc, shapeType.getDimSize(i));
@@ -353,8 +351,8 @@
loc, builder.getIndexType(), listValue,
builder.createOrFold<arith::ConstantIndexOp>(loc, i))
.getResult();
- builder.create<IREE::Util::GlobalStoreOp>(
- loc, dimValue, dynamicDims.globalOps[dynamicDimIdx++].getSymName());
+ dynamicDims.globalOps[dynamicDimIdx++].createStoreOp(loc, dimValue,
+ builder);
}
}
@@ -422,8 +420,7 @@
// Set the dirty flag so that shapes get recalculated as needed.
auto exitBuilder = OpBuilder::atBlockBegin(exitBlock);
auto trueValue = exitBuilder.createOrFold<arith::ConstantIntOp>(loc, 1, 1);
- exitBuilder.create<IREE::Util::GlobalStoreOp>(loc, trueValue,
- dirtyGlobalOp.getName());
+ dirtyGlobalOp.createStoreOp(loc, trueValue, exitBuilder);
exitBuilder.create<mlir::func::ReturnOp>(loc);
}
@@ -522,8 +519,8 @@
llvm::zip_equal(entryBlock->getArguments(), inputDynamicDims)) {
SmallVector<Value> dynamicDims;
for (auto globalOp : inputDynamicDims.globalOps) {
- dynamicDims.push_back(entryBuilder.create<IREE::Util::GlobalLoadOp>(
- arg.getLoc(), globalOp));
+ dynamicDims.push_back(globalOp.createLoadOp(arg.getLoc(), entryBuilder)
+ .getLoadedGlobalValue());
}
callOperands.push_back(entryBuilder.create<IREE::HAL::TensorImportOp>(
arg.getLoc(), inputDynamicDims.tensorType, arg,
@@ -547,16 +544,15 @@
dynamicDims, /*target_storage=*/nullptr, /*name=*/nullptr));
for (auto [dynamicDim, globalOp] :
llvm::zip_equal(dynamicDims, outputDynamicDims.globalOps)) {
- entryBuilder.create<IREE::Util::GlobalStoreOp>(
- result.getLoc(), dynamicDim, globalOp.getSymName());
+ globalOp.createStoreOp(result.getLoc(), dynamicDim, entryBuilder);
}
}
// We recomputed the shapes of the outputs and can clear the dirty flag.
- entryBuilder.create<IREE::Util::GlobalStoreOp>(
+ dirtyGlobalOp.createStoreOp(
entryFuncOp.getLoc(),
entryBuilder.create<arith::ConstantIntOp>(entryFuncOp.getLoc(), 0, 1),
- dirtyGlobalOp.getSymName());
+ entryBuilder);
entryBuilder.create<mlir::func::ReturnOp>(entryFuncOp.getLoc(),
callResults);
diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
index 5e7afe8..463dc54 100644
--- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
+++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
@@ -104,7 +104,7 @@
ArgumentBinding(ElementsAttr attr)
: type(Type::ElementsAttr), elementsAttr(attr) {}
- ArgumentBinding(IREE::Util::GlobalOp globalOp)
+ ArgumentBinding(IREE::Util::GlobalOpInterface globalOp)
: type(Type::GlobalOp), globalOp(globalOp) {}
Type getType() { return type; }
@@ -114,7 +114,7 @@
return elementsAttr;
}
- IREE::Util::GlobalOp getGlobalOp() {
+ IREE::Util::GlobalOpInterface getGlobalOp() {
assert(type == Type::GlobalOp);
return globalOp;
}
@@ -122,7 +122,7 @@
private:
Type type;
ElementsAttr elementsAttr;
- IREE::Util::GlobalOp globalOp;
+ IREE::Util::GlobalOpInterface globalOp;
};
// How to bind results to the original program.
@@ -133,12 +133,12 @@
GlobalOp,
};
- ResultBinding(IREE::Util::GlobalOp globalOp)
+ ResultBinding(IREE::Util::GlobalOpInterface globalOp)
: type(Type::GlobalOp), globalOp(globalOp) {}
Type getType() { return type; }
- IREE::Util::GlobalOp getGlobalOp() {
+ IREE::Util::GlobalOpInterface getGlobalOp() {
assert(type == Type::GlobalOp);
return globalOp;
}
@@ -146,7 +146,7 @@
private:
Type type;
ElementsAttr elementsAttr;
- IREE::Util::GlobalOp globalOp;
+ IREE::Util::GlobalOpInterface globalOp;
};
// Description of a JIT function that we have created for doing some
@@ -222,15 +222,15 @@
Block *entryBlock = &funcOp.getBody().front();
// Find immutable loads.
- for (auto loadOp : funcOp.getOps<IREE::Util::GlobalLoadOp>()) {
- auto globalOp = llvm::dyn_cast_or_null<IREE::Util::GlobalOp>(
+ for (auto loadOp : funcOp.getOps<IREE::Util::GlobalLoadOpInterface>()) {
+ auto globalOp = llvm::dyn_cast_or_null<IREE::Util::GlobalOpInterface>(
sourceSymbolTable.lookup(loadOp.getGlobalAttr().getAttr()));
- if (!globalOp || globalOp.getIsMutable()) {
+ if (!globalOp || globalOp.isGlobalMutable()) {
emitWarning(loadOp.getLoc()) << "skipping consteval initializer: load "
"from mutable globals not supported";
return failure();
}
- Type t = loadOp.getResult().getType();
+ Type t = loadOp.getLoadedGlobalValue().getType();
if (!supportedFeatures.isSupportedAbiType(t)) {
emitWarning(funcOp.getLoc())
<< "skipping consteval initializer: unsupported type for current "
@@ -240,7 +240,7 @@
}
argumentTypes.push_back(t);
BlockArgument entryArg = entryBlock->addArgument(t, loadOp.getLoc());
- loadOp.getResult().replaceAllUsesWith(entryArg);
+ loadOp.getLoadedGlobalValue().replaceAllUsesWith(entryArg);
eraseOps.push_back(loadOp);
desc.argumentBindings.emplace_back(globalOp);
}
@@ -268,15 +268,15 @@
// Find immutable stores, early exiting if not supported.
// The consumers must come after rewrites of the producers above.
- for (auto storeOp : funcOp.getOps<IREE::Util::GlobalStoreOp>()) {
- auto globalOp = llvm::dyn_cast_or_null<IREE::Util::GlobalOp>(
+ for (auto storeOp : funcOp.getOps<IREE::Util::GlobalStoreOpInterface>()) {
+ auto globalOp = llvm::dyn_cast_or_null<IREE::Util::GlobalOpInterface>(
sourceSymbolTable.lookup(storeOp.getGlobalAttr().getAttr()));
- if (!globalOp || globalOp.getIsMutable()) {
+ if (!globalOp || globalOp.isGlobalMutable()) {
emitWarning(storeOp.getLoc()) << "skipping consteval initializer: stor "
"to mutable globals not supported";
return failure();
}
- Type t = storeOp.getValue().getType();
+ Type t = storeOp.getStoredGlobalValue().getType();
if (!supportedFeatures.isSupportedAbiType(t)) {
emitWarning(funcOp.getLoc())
<< "skipping consteval initializer: unsupported type for current "
@@ -285,7 +285,7 @@
return failure();
}
- returns.push_back(storeOp.getValue());
+ returns.push_back(storeOp.getStoredGlobalValue());
returnTypes.push_back(t);
eraseOps.push_back(storeOp);
desc.resultBindings.emplace_back(globalOp);
@@ -417,15 +417,14 @@
break;
case ArgumentBinding::Type::GlobalOp: {
- auto globalValue = arg.getGlobalOp().getInitialValue();
+ auto globalValue = arg.getGlobalOp().getGlobalInitialValue();
if (!globalValue) {
return emitError(jitFunction.loc)
<< "internal error: jit global source initialization order. "
"global "
- << arg.getGlobalOp().getSymName() << " has no value";
+ << arg.getGlobalOp().getGlobalName() << " has no value";
}
- if (failed(
- call.addArgument(arg.getGlobalOp().getLoc(), *globalValue)))
+ if (failed(call.addArgument(arg.getGlobalOp().getLoc(), globalValue)))
return failure();
} break;
}
@@ -443,9 +442,9 @@
TypedAttr attr;
if (failed(call.getResultAsAttr(
resultBinding.getGlobalOp().getLoc(), it.index(),
- resultBinding.getGlobalOp().getType(), attr)))
+ resultBinding.getGlobalOp().getGlobalType(), attr)))
return failure();
- resultBinding.getGlobalOp().setInitialValueAttr(attr);
+ resultBinding.getGlobalOp().setGlobalInitialValue(attr);
break;
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
index 0ad2cd4..07db885 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/ExportBenchmarkFuncs.cpp
@@ -85,8 +85,7 @@
auto barrierOp = initializerBuilder.create<IREE::Util::OptimizationBarrierOp>(
loc, bufferExportOp.getTarget());
// util.global.store
- initializerBuilder.create<IREE::Util::GlobalStoreOp>(
- loc, barrierOp.getResult(0), globalOp.getName());
+ globalOp.createStoreOp(loc, barrierOp.getResult(0), initializerBuilder);
initializerBuilder.create<IREE::Util::ReturnOp>(loc);
return globalOp;
@@ -233,8 +232,9 @@
auto blockBuilder = OpBuilder::atBlockBegin(block);
SmallVector<Value> args;
for (int i = 0, e = entryFuncOp.getNumArguments(); i < e; ++i) {
- args.push_back(blockBuilder.createOrFold<IREE::Util::GlobalLoadOp>(
- loc, dummyInputVariableOps[i]));
+ args.push_back(dummyInputVariableOps[i]
+ .createLoadOp(loc, blockBuilder)
+ .getLoadedGlobalValue());
}
auto callOp = blockBuilder.create<mlir::func::CallOp>(loc, entryFuncOp, args);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index 1fa0d6a..3e7e4b9 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -193,8 +193,7 @@
loc, globalOp.getType(), allocator, queueAffinity, memoryTypes,
bufferUsage, indexSet.get(totalLength));
- initBuilder.create<IREE::Util::GlobalStoreOp>(loc, allocateOp.getResult(),
- globalOp.getNameAttr());
+ globalOp.createStoreOp(loc, allocateOp.getResult(), initBuilder);
initBuilder.create<IREE::Util::ReturnOp>(loc);
return globalOp;
@@ -289,9 +288,8 @@
}
// Push descriptor sets.
- auto buffer =
- funcBuilder.create<IREE::Util::GlobalLoadOp>(loc, bufferGlobalOp)
- .getResult();
+ Value buffer =
+ bufferGlobalOp.createLoadOp(loc, funcBuilder).getLoadedGlobalValue();
int64_t currentSet = -1;
SmallVector<IREE::HAL::DescriptorSetBindingValue> bindingValues;
auto flushSet = [&]() {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp
index d97822e..05c3868 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeDispatchInstrumentation.cpp
@@ -150,8 +150,7 @@
Value buffer = initializerBuilder.create<IREE::Stream::ResourceAllocOp>(
loc, globalOp.getType(), bufferSize,
/*uninitialized=*/true, /*affinity=*/nullptr);
- initializerBuilder.create<IREE::Util::GlobalStoreOp>(loc, buffer,
- globalOp);
+ globalOp.createStoreOp(loc, buffer, initializerBuilder);
initializerBuilder.create<IREE::Util::ReturnOp>(loc);
}
@@ -232,14 +231,15 @@
auto parentBuilder = OpBuilder(executeOp);
// Load the ringbuffer and capture it for use within the execute region.
- auto loadOp =
- parentBuilder.create<IREE::Util::GlobalLoadOp>(loc, globalOp);
+ auto loadedValue =
+ globalOp.createLoadOp(loc, parentBuilder).getLoadedGlobalValue();
Value zero = parentBuilder.create<arith::ConstantIndexOp>(loc, 0);
Value bufferSize =
parentBuilder.create<arith::ConstantOp>(loc, bufferSizeAttr);
- executeOp.getResourceOperandsMutable().append(loadOp.getResult());
+ executeOp.getResourceOperandsMutable().append(loadedValue);
executeOp.getResourceOperandSizesMutable().append(bufferSize);
- auto bufferArg = executeOp.getBody().addArgument(loadOp.getType(), loc);
+ auto bufferArg =
+ executeOp.getBody().addArgument(loadedValue.getType(), loc);
// Walk dispatches and pass them the ringbuffer and their unique ID.
executeOp.walk([&](IREE::Stream::CmdDispatchOp dispatchOp) {
@@ -342,7 +342,7 @@
// Export the device buffer containing the instrument data.
Value buffer =
- queryBuilder.create<IREE::Util::GlobalLoadOp>(loc, globalOp);
+ globalOp.createLoadOp(loc, queryBuilder).getLoadedGlobalValue();
Value bufferSize =
queryBuilder.create<arith::ConstantOp>(loc, bufferSizeAttr);
auto bufferViewType = moduleBuilder.getType<IREE::HAL::BufferViewType>();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index 32e754e..01ba030 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -126,8 +126,7 @@
Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
Value layout = blockBuilder.createOrFold<DescriptorSetLayoutCreateOp>(
loc, layoutType, device, flags, bindingAttrs);
- blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layout,
- globalOp.getName());
+ globalOp.createStoreOp(loc, layout, blockBuilder);
blockBuilder.create<IREE::Util::ReturnOp>(loc);
return globalOp;
@@ -170,10 +169,9 @@
OpBuilder::atBlockEnd(initializerOp.addEntryBlock());
SmallVector<Value> setLayoutValues;
for (auto setLayoutGlobalOp : setLayoutGlobalOps) {
- auto setLayoutValue = blockBuilder.createOrFold<IREE::Util::GlobalLoadOp>(
- loc, DescriptorSetLayoutType::get(loc.getContext()),
- setLayoutGlobalOp.getSymName());
- setLayoutValues.push_back(setLayoutValue);
+ setLayoutValues.push_back(
+ setLayoutGlobalOp.createLoadOp(loc, blockBuilder)
+ .getLoadedGlobalValue());
}
// TODO(multi-device): pass in resolve info to the call and reuse.
Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder);
@@ -181,8 +179,7 @@
loc, layoutType, device,
blockBuilder.getIndexAttr(layoutAttr.getPushConstants()),
setLayoutValues);
- blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, layout,
- globalOp.getName());
+ globalOp.createStoreOp(loc, layout, blockBuilder);
blockBuilder.create<IREE::Util::ReturnOp>(loc);
return globalOp;
@@ -239,9 +236,8 @@
auto pipelineLayoutGlobalOp =
definePipelineLayoutOp(executableOp.getLoc(), exportOp.getLayout());
pipelineLayoutValues.push_back(
- caseBuilder.createOrFold<IREE::Util::GlobalLoadOp>(
- loc, PipelineLayoutType::get(loc.getContext()),
- pipelineLayoutGlobalOp.getSymName()));
+ pipelineLayoutGlobalOp.createLoadOp(loc, caseBuilder)
+ .getLoadedGlobalValue());
}
// Inline constant initializer from the variant.
@@ -278,8 +274,7 @@
defaultBuilder.create<scf::YieldOp>(loc, nullValue);
auto executableValue = switchOp.getResult(0);
- blockBuilder.create<IREE::Util::GlobalStoreOp>(loc, executableValue,
- globalOp.getName());
+ globalOp.createStoreOp(loc, executableValue, blockBuilder);
blockBuilder.create<IREE::Util::ReturnOp>(loc);
}
@@ -323,10 +318,9 @@
OpBuilder builder(lookupOp);
auto globalOp = defineDescriptorSetLayoutOp(
lookupOp.getLoc(), lookupOp.getBindings(), lookupOp.getFlags());
- auto loadOp = builder.create<IREE::Util::GlobalLoadOp>(
- lookupOp.getLoc(), DescriptorSetLayoutType::get(lookupOp.getContext()),
- globalOp.getSymName());
- lookupOp.replaceAllUsesWith(loadOp.getOperation());
+ auto loadedValue = globalOp.createLoadOp(lookupOp.getLoc(), builder)
+ .getLoadedGlobalValue();
+ lookupOp.replaceAllUsesWith(loadedValue);
lookupOp.erase();
}
@@ -334,10 +328,9 @@
OpBuilder builder(lookupOp);
auto globalOp =
definePipelineLayoutOp(lookupOp.getLoc(), lookupOp.getLayout());
- auto loadOp = builder.create<IREE::Util::GlobalLoadOp>(
- lookupOp.getLoc(), PipelineLayoutType::get(lookupOp.getContext()),
- globalOp.getSymName());
- lookupOp.replaceAllUsesWith(loadOp.getOperation());
+ auto loadedValue = globalOp.createLoadOp(lookupOp.getLoc(), builder)
+ .getLoadedGlobalValue();
+ lookupOp.replaceAllUsesWith(loadedValue);
lookupOp.erase();
}
@@ -347,10 +340,9 @@
assert(executableIt != executableCache_.end() &&
"executable must have been cached");
auto globalOp = executableIt->second;
- auto loadOp = builder.create<IREE::Util::GlobalLoadOp>(
- lookupOp.getLoc(), ExecutableType::get(lookupOp.getContext()),
- globalOp.getSymName());
- lookupOp.replaceAllUsesWith(loadOp.getOperation());
+ auto loadedValue = globalOp.createLoadOp(lookupOp.getLoc(), builder)
+ .getLoadedGlobalValue();
+ lookupOp.replaceAllUsesWith(loadedValue);
lookupOp.erase();
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
index c4184b2..9857a33 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -97,21 +97,18 @@
fusedLoc, funcBuilder.getI1Type(), queryType, device,
anyQueryOp.getCategoryAttr(), anyQueryOp.getKeyAttr(),
anyQueryOp.getDefaultValueAttr());
- funcBuilder.create<IREE::Util::GlobalStoreOp>(fusedLoc, queryOp.getOk(),
- okGlobalOp.getName());
- funcBuilder.create<IREE::Util::GlobalStoreOp>(
- fusedLoc, queryOp.getValue(), valueGlobalOp.getName());
+ okGlobalOp.createStoreOp(fusedLoc, queryOp.getOk(), funcBuilder);
+ valueGlobalOp.createStoreOp(fusedLoc, queryOp.getValue(), funcBuilder);
funcBuilder.create<IREE::Util::ReturnOp>(fusedLoc);
for (auto queryOp : queryOps) {
OpBuilder replaceBuilder(queryOp);
- auto okLoadOp = replaceBuilder.create<IREE::Util::GlobalLoadOp>(
- fusedLoc, okGlobalOp.getType(), okGlobalOp.getName());
- auto resultLoadOp = replaceBuilder.create<IREE::Util::GlobalLoadOp>(
- fusedLoc, valueGlobalOp.getType(), valueGlobalOp.getName());
+ auto okLoadOp = okGlobalOp.createLoadOp(fusedLoc, replaceBuilder);
+ auto resultLoadOp =
+ valueGlobalOp.createLoadOp(fusedLoc, replaceBuilder);
queryOp.replaceAllUsesWith(ValueRange{
- okLoadOp.getResult(),
- resultLoadOp.getResult(),
+ okLoadOp.getLoadedGlobalValue(),
+ resultLoadOp.getLoadedGlobalValue(),
});
queryOp.erase();
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
index 1db93f9..c29fa9a 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/UtilToStream/Patterns.cpp
@@ -137,10 +137,9 @@
initialValueSize = rewriter.create<IREE::Stream::ResourceSizeOp>(
globalOp.getLoc(), indexType, initialValue);
}
- rewriter.create<IREE::Util::GlobalStoreOp>(
- globalOp.getLoc(), initialValue, resourceOp.getSymName());
- rewriter.create<IREE::Util::GlobalStoreOp>(
- globalOp.getLoc(), initialValueSize, resourceSizeOp.getSymName());
+ resourceOp.createStoreOp(globalOp.getLoc(), initialValue, rewriter);
+ resourceSizeOp.createStoreOp(globalOp.getLoc(), initialValueSize,
+ rewriter);
rewriter.create<IREE::Util::ReturnOp>(globalOp.getLoc());
}
@@ -162,7 +161,7 @@
// Only apply to expanded types (tensors/etc).
if (!isExpandedType(loadOp.getType()))
return failure();
- auto &expandedGlobal = expansionState->globalMap[adaptor.getGlobal()];
+ auto &expandedGlobal = this->expansionState->globalMap[adaptor.getGlobal()];
// Insert a load/transfer to the unknown resource lifetime.
auto unknownType = IREE::Stream::ResourceType::get(rewriter.getContext());
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp
index 30fc749..54b949d 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PropagateTimepoints.cpp
@@ -298,43 +298,39 @@
// %t = util.global.load @foo : !stream.timepoint
// %0 = util.global.load @foo : !stream.resource
// %1 = stream.timepoint.await %t, %0
-static void expandGlobalLoadOp(IREE::Util::GlobalLoadOp op,
+static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op,
ExpandedGlobalMap &globalMap,
IRMapping &resourceTimepointMap) {
if (!usesResources(op))
return;
OpBuilder builder(op);
- auto &expandedGlobal = globalMap[op.getGlobal()];
- auto timepoint =
- builder
- .create<IREE::Util::GlobalLoadOp>(
- op.getLoc(), IREE::Stream::TimepointType::get(op.getContext()),
- expandedGlobal.timepointOp.getName())
- .getResult();
- resourceTimepointMap.map(op.getResult(), timepoint);
+ auto &expandedGlobal = globalMap[op.getGlobalName()];
+ auto timepoint = expandedGlobal.timepointOp.createLoadOp(op.getLoc(), builder)
+ .getLoadedGlobalValue();
+ resourceTimepointMap.map(op.getLoadedGlobalValue(), timepoint);
// HACK: queryValueSize may insert other ops that we don't want to replace.
// TODO(benvanik): carry the size so we don't need to guess here.
SmallPtrSet<Operation *, 2> replacementExceptions;
builder.setInsertionPointAfter(op);
auto resultSize = IREE::Util::SizeAwareTypeInterface::queryValueSize(
- op.getLoc(), op.getResult(), builder);
+ op.getLoc(), op.getLoadedGlobalValue(), builder);
if (resultSize) {
replacementExceptions.insert(resultSize.getDefiningOp());
} else {
- auto sizeOp = builder.create<IREE::Stream::ResourceSizeOp>(op.getLoc(),
- op.getResult());
+ auto sizeOp = builder.create<IREE::Stream::ResourceSizeOp>(
+ op.getLoc(), op.getLoadedGlobalValue());
replacementExceptions.insert(sizeOp);
resultSize = sizeOp.getResult();
}
assert(resultSize && "need to be able to get a size");
auto awaitOp = builder.create<IREE::Stream::TimepointAwaitOp>(
- op.getLoc(), op.getResult(), resultSize, timepoint);
+ op.getLoc(), op.getLoadedGlobalValue(), resultSize, timepoint);
replacementExceptions.insert(awaitOp);
- op.getResult().replaceAllUsesExcept(awaitOp.getResults().front(),
- replacementExceptions);
+ op.getLoadedGlobalValue().replaceAllUsesExcept(awaitOp.getResults().front(),
+ replacementExceptions);
}
// Moves awaits from global stores to loads.
@@ -346,19 +342,18 @@
// ->
// util.global.store %t, @foo_timepoint : !stream.timepoint
// util.global.store %0, @foo : !stream.resource
-static void expandGlobalStoreOp(IREE::Util::GlobalStoreOp op,
+static void expandGlobalStoreOp(IREE::Util::GlobalStoreOpInterface op,
ExpandedGlobalMap &globalMap,
IRMapping &resourceTimepointMap) {
if (!usesResources(op))
return;
OpBuilder builder(op);
- auto timepointOperand = consumeTimepoint(op.getLoc(), op.getValue(),
- resourceTimepointMap, builder);
- auto &expandedGlobal = globalMap[op.getGlobal()];
- builder.create<IREE::Util::GlobalStoreOp>(
- op.getLoc(), timepointOperand.first,
- expandedGlobal.timepointOp.getName());
- op.getValueMutable().assign(timepointOperand.second);
+ auto timepointOperand = consumeTimepoint(
+ op.getLoc(), op.getStoredGlobalValue(), resourceTimepointMap, builder);
+ auto &expandedGlobal = globalMap[op.getGlobalName()];
+ expandedGlobal.timepointOp.createStoreOp(op.getLoc(), timepointOperand.first,
+ builder);
+ op.setStoredGlobalValue(timepointOperand.second);
}
static void expandInitializerOp(IREE::Util::InitializerOp op,
@@ -605,9 +600,9 @@
// awaits.
static void expandTimepoints(Operation *op, ExpandedGlobalMap &globalMap,
IRMapping &resourceTimepointMap) {
- if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOp>(op)) {
+ if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOpInterface>(op)) {
expandGlobalLoadOp(loadOp, globalMap, resourceTimepointMap);
- } else if (auto storeOp = dyn_cast<IREE::Util::GlobalStoreOp>(op)) {
+ } else if (auto storeOp = dyn_cast<IREE::Util::GlobalStoreOpInterface>(op)) {
expandGlobalStoreOp(storeOp, globalMap, resourceTimepointMap);
} else if (auto initializerOp = dyn_cast<IREE::Util::InitializerOp>(op)) {
expandInitializerOp(initializerOp, globalMap, resourceTimepointMap);
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
index 7129858..1e6fada 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.cpp
@@ -67,12 +67,12 @@
if (!isLegalConstExprRootType(info->op.getGlobalType()))
return;
for (auto *use : info->uses) {
- auto loadOp = llvm::dyn_cast<GlobalLoadOp>(use);
+ auto loadOp = llvm::dyn_cast<GlobalLoadOpInterface>(use);
if (!loadOp)
continue;
if (!isHoistableToRootOp(rootOp, loadOp))
continue;
- constantRoots[loadOp.getResult()] = loadOp;
+ constantRoots[loadOp.getLoadedGlobalValue()] = loadOp;
}
});
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/Patterns.cpp
index fb093df..7821535 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/MemRefToUtil/Patterns.cpp
@@ -145,8 +145,8 @@
auto constantOp = initializerBuilder.create<IREE::Util::BufferConstantOp>(
globalOp.getLoc(), /*name=*/nullptr, globalOp.getInitialValueAttr(),
alignmentAttr, /*mimeType=*/nullptr);
- initializerBuilder.create<IREE::Util::GlobalStoreOp>(
- globalOp.getLoc(), constantOp.getResult(), newOp.getName());
+ newOp.createStoreOp(globalOp.getLoc(), constantOp.getResult(),
+ initializerBuilder);
initializerBuilder.create<IREE::Util::ReturnOp>(globalOp.getLoc());
return success();
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilExternalModels.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilExternalModels.cpp
index 9add0f0..5e39aad 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilExternalModels.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilExternalModels.cpp
@@ -126,6 +126,35 @@
op->removeAttr("noinline");
}
}
+
+ IREE::Util::GlobalLoadOpInterface createLoadOp(Operation *op, Location loc,
+ OpBuilder &builder) const {
+ auto globalOp = cast<ml_program::GlobalOp>(op);
+ if (globalOp.getIsMutable()) {
+ return cast<IREE::Util::GlobalLoadOpInterface>(
+ builder
+ .create<ml_program::GlobalLoadOp>(
+ loc, globalOp.getType(), FlatSymbolRefAttr::get(globalOp))
+ .getOperation());
+ } else {
+ return cast<IREE::Util::GlobalLoadOpInterface>(
+ builder
+ .create<ml_program::GlobalLoadConstOp>(
+ loc, globalOp.getType(), FlatSymbolRefAttr::get(globalOp))
+ .getOperation());
+ }
+ }
+
+ IREE::Util::GlobalStoreOpInterface createStoreOp(Operation *op, Location loc,
+ Value value,
+ OpBuilder &builder) const {
+ auto globalOp = cast<ml_program::GlobalOp>(op);
+ return cast<IREE::Util::GlobalStoreOpInterface>(
+ builder
+ .create<ml_program::GlobalStoreOp>(
+ loc, FlatSymbolRefAttr::get(globalOp), value)
+ .getOperation());
+ }
};
} // namespace
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
index e91d4f5..5f89f2c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td
@@ -296,6 +296,22 @@
}
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Creates an operation loading the value of the global.
+ }],
+ /*retTy=*/"IREE::Util::GlobalLoadOpInterface",
+ /*methodName=*/"createLoadOp",
+ /*args=*/(ins "Location":$loc, "OpBuilder &":$builder)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Creates an operation storing the value of the global.
+ }],
+ /*retTy=*/"IREE::Util::GlobalStoreOpInterface",
+ /*methodName=*/"createStoreOp",
+ /*args=*/(ins "Location":$loc, "Value":$value, "OpBuilder &":$builder)
+ >,
];
let verify = [{
@@ -349,6 +365,35 @@
return $_op->getResult(0);
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if the global whose address is taken is immutable outside
+ of initializers.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isGlobalImmutable",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return $_op.getIsImmutable();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Sets whether the global is immutable outside of initializers.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"setGlobalImmutable",
+ /*args=*/(ins "bool":$value),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ if (value) {
+ $_op.setIsImmutableAttr(UnitAttr::get($_op.getContext()));
+ } else {
+ $_op.removeIsImmutableAttr();
+ }
+ }]
+ >,
];
}
@@ -386,6 +431,34 @@
return $_op->getResult(0);
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if the global loaded is immutable outside of initializers.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isGlobalImmutable",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return $_op.getIsImmutable();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Sets whether the global is immutable outside of initializers.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"setGlobalImmutable",
+ /*args=*/(ins "bool":$value),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ if (value) {
+ $_op.setIsImmutableAttr(UnitAttr::get($_op.getContext()));
+ } else {
+ $_op.removeIsImmutableAttr();
+ }
+ }]
+ >,
];
}
@@ -423,6 +496,18 @@
return $_op.getValue();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Sets the value stored by the operation.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"setStoredGlobalValue",
+ /*args=*/(ins "Value":$value),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return $_op.getValueMutable().assign(value);
+ }]
+ >,
];
}
@@ -485,6 +570,18 @@
return $_op.getValue();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Sets the value stored by the operation.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"setStoredGlobalValue",
+ /*args=*/(ins "Value":$value),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return $_op.getValueMutable().assign(value);
+ }]
+ >,
];
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index 13863b8..e6e694c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -591,17 +591,16 @@
namespace {
/// Turns util.global.address -> util.global.load.indirect into a direct load.
-class PropagateGlobalLoadAddress
- : public OpRewritePattern<GlobalLoadIndirectOp> {
- using OpRewritePattern::OpRewritePattern;
-
-public:
- LogicalResult matchAndRewrite(GlobalLoadIndirectOp op,
+template <typename IndirectOpT, typename DirectOpT>
+struct PropagateGlobalLoadAddress : public OpRewritePattern<IndirectOpT> {
+ using OpRewritePattern<IndirectOpT>::OpRewritePattern;
+ LogicalResult matchAndRewrite(IndirectOpT op,
PatternRewriter &rewriter) const override {
if (auto addressOp = dyn_cast_or_null<GlobalAddressOpInterface>(
op.getGlobal().getDefiningOp())) {
- rewriter.replaceOpWithNewOp<GlobalLoadOp>(op, op.getResult().getType(),
- addressOp.getGlobalAttr());
+ rewriter.replaceOpWithNewOp<DirectOpT>(
+ op, op.getResult().getType(), addressOp.getGlobalAttr(),
+ addressOp.isGlobalImmutable() ? rewriter.getUnitAttr() : UnitAttr{});
return success();
}
return failure();
@@ -612,7 +611,8 @@
void GlobalLoadIndirectOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
- results.insert<PropagateGlobalLoadAddress>(context);
+ results.insert<PropagateGlobalLoadAddress<IREE::Util::GlobalLoadIndirectOp,
+ IREE::Util::GlobalLoadOp>>(context);
}
namespace {
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 8b32bbb..75ea2c3 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -1525,6 +1525,19 @@
build(builder, result, name, isMutable, type, std::nullopt, attrs);
}
+IREE::Util::GlobalLoadOpInterface GlobalOp::createLoadOp(Location loc,
+ OpBuilder &builder) {
+ // TODO(benvanik): create with the immutable flag if the global is immutable.
+ // Today we avoid this and let analysis add the immutable flag when safe
+ // (not in initializers/etc).
+ return builder.create<IREE::Util::GlobalLoadOp>(loc, getType(), getSymName());
+}
+
+IREE::Util::GlobalStoreOpInterface
+GlobalOp::createStoreOp(Location loc, Value value, OpBuilder &builder) {
+ return builder.create<IREE::Util::GlobalStoreOp>(loc, value, getSymName());
+}
+
void GlobalAddressOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), Twine("ptr_" + getGlobal()).str());
@@ -1545,21 +1558,17 @@
void GlobalLoadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- // HACK: works around the lack of symbol side effects in mlir by only saying
- // we have a side-effect if the variable we are loading is mutable.
- auto globalOp =
- SymbolTable::lookupNearestSymbolFrom<GlobalOp>(*this, getGlobalAttr());
- assert(globalOp);
- if (globalOp.getIsMutable()) {
+ // HACK: mlir doesn't have symbol side effects so we have to mark as a global
+ // read if not immutable and not in an initializer.
+ if (!isGlobalImmutable())
effects.emplace_back(MemoryEffects::Read::get());
- }
}
-LogicalResult GlobalLoadIndirectOp::verify() {
- Operation *op = getOperation();
+LogicalResult
+verifyGlobalLoadIndirectOp(IREE::Util::GlobalLoadIndirectOpInterface op) {
auto globalType =
- cast<IREE::Util::PtrType>(getGlobal().getType()).getTargetType();
- auto loadType = getResult().getType();
+ cast<IREE::Util::PtrType>(op.getGlobal().getType()).getTargetType();
+ auto loadType = op.getLoadedGlobalValue().getType();
if (!isGlobalTypeCompatible(globalType, loadType)) {
return op->emitOpError() << "global type mismatch; global pointer is "
<< globalType << " but load is " << loadType;
@@ -1567,6 +1576,10 @@
return success();
}
+LogicalResult GlobalLoadIndirectOp::verify() {
+ return verifyGlobalLoadIndirectOp(*this);
+}
+
void GlobalStoreOp::build(OpBuilder &builder, OperationState &state,
Value value, IREE::Util::GlobalOpInterface globalOp,
ArrayRef<NamedAttribute> attrs) {
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index de59620..6113201 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -789,6 +789,11 @@
)>,
];
+ let extraClassDeclaration = [{
+ IREE::Util::GlobalLoadOpInterface createLoadOp(Location loc, OpBuilder &builder);
+ IREE::Util::GlobalStoreOpInterface createStoreOp(Location loc, Value value, OpBuilder &builder);
+ }];
+
let hasCanonicalizer = 1;
}
@@ -804,13 +809,15 @@
}];
let arguments = (ins
- Util_GlobalRefAttr:$global
+ Util_GlobalRefAttr:$global,
+ UnitAttr:$is_immutable
);
let results = (outs
Util_AnyGlobalPtr:$result
);
let assemblyFormat = [{
+ (`immutable` $is_immutable^)?
$global attr-dict `:` qualified(type($result))
}];
@@ -834,13 +841,15 @@
}];
let arguments = (ins
- Arg<Util_GlobalRefAttr, "", []>:$global
+ Arg<Util_GlobalRefAttr, "", []>:$global,
+ UnitAttr:$is_immutable
);
let results = (outs
AnyType:$result
);
let assemblyFormat = [{
+ (`immutable` $is_immutable^)?
$global attr-dict `:` type($result)
}];
@@ -867,13 +876,15 @@
}];
let arguments = (ins
- Arg<Util_AnyGlobalPtr, "", []>:$global
+ Arg<Util_AnyGlobalPtr, "", []>:$global,
+ UnitAttr:$is_immutable
);
let results = (outs
AnyType:$result
);
let assemblyFormat = [{
+ (`immutable` $is_immutable^)?
$global attr-dict `:` qualified(type($global)) `->` type($result)
}];
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
index f3dd2ff..d9c4a92 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
@@ -287,6 +287,10 @@
}
// TODO(benvanik): allow type conversion here? probably better on the indirect
// access ops instead as it's then easier to fold the conversion.
+ if (addressOp.isGlobalImmutable() && globalOp.isGlobalMutable()) {
+ return addressOp->emitOpError()
+ << "is marked as immutable but the global is mutable";
+ }
return success();
}
@@ -303,6 +307,10 @@
<< "global type mismatch; global " << globalOp.getGlobalName()
<< " is " << globalOp.getGlobalType() << " but load is " << loadType;
}
+ if (loadOp.isGlobalImmutable() && globalOp.isGlobalMutable()) {
+ return loadOp->emitOpError()
+ << "is marked as immutable but the global is mutable";
+ }
return success();
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
index 52fa07e..a8b7661 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
@@ -250,10 +250,19 @@
return GlobalAction::PRESERVE;
if (!global.storeOps.empty())
return GlobalAction::PRESERVE;
- if (!global.op.isGlobalMutable())
- return GlobalAction::PRESERVE;
+ bool didChangeAny = global.op.isGlobalMutable() != false;
global.op.setGlobalMutable(false);
- return GlobalAction::UPDATE;
+ for (auto loadOp : global.loadOps) {
+ // NOTE: we don't set immutable on loads in initializers today.
+ // We should be able to, though, with a bit better analysis.
+ if (!loadOp->getParentOfType<IREE::Util::InitializerOpInterface>()) {
+ if (!loadOp.isGlobalImmutable()) {
+ loadOp.setGlobalImmutable(true);
+ didChangeAny = true;
+ }
+ }
+ }
+ return didChangeAny ? GlobalAction::UPDATE : GlobalAction::PRESERVE;
});
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp
index b00caba..023031a 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OutlineConstants.cpp
@@ -101,14 +101,12 @@
auto globalOp = pair.second;
OpBuilder builder(moduleOp.getContext());
builder.setInsertionPoint(originalOp);
- auto loadOp = builder.create<IREE::Util::GlobalLoadOp>(
- originalOp->getLoc(), globalOp.getType(),
- SymbolRefAttr::get(globalOp));
+ auto loadOp = globalOp.createLoadOp(originalOp->getLoc(), builder);
Value replacement;
if (auto constantOp = dyn_cast<arith::ConstantOp>(originalOp)) {
// Directly replace constant with global constant value.
- replacement = loadOp.getResult();
+ replacement = loadOp.getLoadedGlobalValue();
} else {
assert(false && "unhandled constant op type");
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp
index 228be14..b5d0a20 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp
@@ -321,37 +321,31 @@
// %l = util.global.load @foo_length : index
// %1 = stream.resource.subview %0[%o] :
// !stream.resource<*>{%s} -> !stream.resource<*>{%l}
-static void expandGlobalLoadOp(IREE::Util::GlobalLoadOp op,
+static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op,
ExpandedGlobalMap &globalMap, IndexSet &indexSet,
SubrangeMap &subrangeMap) {
if (!usesResources(op))
return;
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
- auto indexType = builder.getIndexType();
- auto &expandedGlobal = globalMap[op.getGlobal()];
+ auto &expandedGlobal = globalMap[op.getGlobalName()];
Subrange subrange;
- subrange.resource = op.getResult();
+ subrange.resource = op.getLoadedGlobalValue();
subrange.resourceSize =
- builder
- .create<IREE::Util::GlobalLoadOp>(
- op.getLoc(), indexType, expandedGlobal.resourceSizeOp.getName())
- .getResult();
+ expandedGlobal.resourceSizeOp.createLoadOp(op.getLoc(), builder)
+ .getLoadedGlobalValue();
subrange.subrangeOffset =
- builder
- .create<IREE::Util::GlobalLoadOp>(
- op.getLoc(), indexType, expandedGlobal.subrangeOffsetOp.getName())
- .getResult();
+ expandedGlobal.subrangeOffsetOp.createLoadOp(op.getLoc(), builder)
+ .getLoadedGlobalValue();
subrange.subrangeLength =
- builder
- .create<IREE::Util::GlobalLoadOp>(
- op.getLoc(), indexType, expandedGlobal.subrangeLengthOp.getName())
- .getResult();
- subrangeMap[op.getResult()] = subrange;
+ expandedGlobal.subrangeLengthOp.createLoadOp(op.getLoc(), builder)
+ .getLoadedGlobalValue();
+ subrangeMap[op.getLoadedGlobalValue()] = subrange;
auto newSubrange = subrange.getResourceType().createSubrangeOp(
op.getLoc(), subrange.resource, subrange.resourceSize,
subrange.subrangeOffset, subrange.subrangeLength, builder);
- op.getResult().replaceAllUsesExcept(newSubrange, newSubrange.getDefiningOp());
+ op.getLoadedGlobalValue().replaceAllUsesExcept(newSubrange,
+ newSubrange.getDefiningOp());
}
// Moves resource subranges from global stores to loads.
@@ -366,16 +360,16 @@
// util.global.store %s, @foo_size : index
// util.global.store %o, @foo_offset : index
// util.global.store %l, @foo_length : index
-static void expandGlobalStoreOp(IREE::Util::GlobalStoreOp op,
+static void expandGlobalStoreOp(IREE::Util::GlobalStoreOpInterface op,
ExpandedGlobalMap &globalMap,
IndexSet &indexSet, SubrangeMap &subrangeMap) {
if (!usesResources(op))
return;
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
- auto subrange = consumeSubrange(op.getLoc(), op.getValue(), subrangeMap,
- indexSet, builder);
- auto &expandedGlobal = globalMap[op.getGlobal()];
+ auto subrange = consumeSubrange(op.getLoc(), op.getStoredGlobalValue(),
+ subrangeMap, indexSet, builder);
+ auto &expandedGlobal = globalMap[op.getGlobalName()];
builder.create<IREE::Util::GlobalStoreOp>(
op.getLoc(), subrange.resource, expandedGlobal.resourceOp.getName());
builder.create<IREE::Util::GlobalStoreOp>(
@@ -585,9 +579,9 @@
return updateSubrangeOp(subrangeOp, indexSet, subrangeMap);
}
- if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOp>(op)) {
+ if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOpInterface>(op)) {
return expandGlobalLoadOp(loadOp, globalMap, indexSet, subrangeMap);
- } else if (auto storeOp = dyn_cast<IREE::Util::GlobalStoreOp>(op)) {
+ } else if (auto storeOp = dyn_cast<IREE::Util::GlobalStoreOpInterface>(op)) {
return expandGlobalStoreOp(storeOp, globalMap, indexSet, subrangeMap);
} else if (auto initializerOp = dyn_cast<IREE::Util::InitializerOp>(op)) {
return expandInitializerOp(initializerOp, globalMap, indexSet, subrangeMap);
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/fold_globals.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/fold_globals.mlir
index ce0226e..6eaeb1b 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/fold_globals.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/fold_globals.mlir
@@ -39,7 +39,7 @@
// CHECK-NOT: util.global private mutable @chained1 : index
util.global private mutable @chained1 : index
func.func @foo() -> index {
- // CHECK: %[[VALUE:.+]] = util.global.load @chained0 : index
+ // CHECK: %[[VALUE:.+]] = util.global.load immutable @chained0 : index
%0 = util.global.load @chained0 : index
// CHECK-NOT: util.global.store
util.global.store %0, @chained1 : index
@@ -135,9 +135,9 @@
// CHECK-NOT: util.global private @dupeCst1
util.global private @dupeCst1 {inlining_policy = #util.inline.never} = 5 : index
func.func @foo() -> (index, index) {
- // CHECK-DAG: %[[VALUE0:.+]] = util.global.load @dupeCst0
+ // CHECK-DAG: %[[VALUE0:.+]] = util.global.load immutable @dupeCst0
%0 = util.global.load @dupeCst0 : index
- // CHECK-DAG: %[[VALUE1:.+]] = util.global.load @dupeCst0
+ // CHECK-DAG: %[[VALUE1:.+]] = util.global.load immutable @dupeCst0
%1 = util.global.load @dupeCst1 : index
// CHECK: return %[[VALUE0]], %[[VALUE1]]
return %0, %1 : index, index
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp
index 2aacef3..39dc7c0 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertGlobalOps.cpp
@@ -17,7 +17,6 @@
struct InitializerOpConversion
: public OpConversionPattern<IREE::Util::InitializerOp> {
using OpConversionPattern::OpConversionPattern;
-
LogicalResult
matchAndRewrite(IREE::Util::InitializerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -40,7 +39,6 @@
struct ReturnOpConversion : public OpConversionPattern<IREE::Util::ReturnOp> {
using OpConversionPattern::OpConversionPattern;
-
LogicalResult
matchAndRewrite(IREE::Util::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -49,11 +47,10 @@
}
};
-class GlobalOpConversion : public OpConversionPattern<IREE::Util::GlobalOp> {
-public:
+struct GlobalOpConversion : public OpConversionPattern<IREE::Util::GlobalOp> {
+ TypeConverter &typeConverter;
GlobalOpConversion(MLIRContext *context, TypeConverter &typeConverter)
: OpConversionPattern(context), typeConverter(typeConverter) {}
-
LogicalResult
matchAndRewrite(IREE::Util::GlobalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -114,72 +111,61 @@
return success();
}
-
-private:
- TypeConverter &typeConverter;
};
-class GlobalAddressOpConversion
+struct GlobalAddressOpConversion
: public OpConversionPattern<IREE::Util::GlobalAddressOp> {
-public:
+ TypeConverter &typeConverter;
GlobalAddressOpConversion(MLIRContext *context, TypeConverter &typeConverter)
: OpConversionPattern(context), typeConverter(typeConverter) {}
-
LogicalResult
matchAndRewrite(IREE::Util::GlobalAddressOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalAddressOp>(
- op, typeConverter.convertType(op.getType()), op.getGlobal());
+ op, typeConverter.convertType(op.getType()), op.getGlobalAttr(),
+ op.getIsImmutableAttr());
return success();
}
-
-private:
- TypeConverter &typeConverter;
};
-class GlobalLoadOpConversion
+struct GlobalLoadOpConversion
: public OpConversionPattern<IREE::Util::GlobalLoadOp> {
-public:
+ TypeConverter &typeConverter;
GlobalLoadOpConversion(MLIRContext *context, TypeConverter &typeConverter)
: OpConversionPattern(context), typeConverter(typeConverter) {}
-
LogicalResult
matchAndRewrite(IREE::Util::GlobalLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto operandType = op.getType();
auto convertedType = typeConverter.convertType(operandType);
if (IREE::VM::RefType::isCompatible(operandType)) {
- rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadRefOp>(op, convertedType,
- op.getGlobal());
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadRefOp>(
+ op, convertedType, op.getGlobalAttr(), adaptor.getIsImmutableAttr());
} else if (convertedType.isInteger(32)) {
- rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadI32Op>(op, convertedType,
- op.getGlobal());
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadI32Op>(
+ op, convertedType, op.getGlobalAttr(), adaptor.getIsImmutableAttr());
} else if (convertedType.isInteger(64)) {
- rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadI64Op>(op, convertedType,
- op.getGlobal());
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadI64Op>(
+ op, convertedType, op.getGlobalAttr(), adaptor.getIsImmutableAttr());
} else if (convertedType.isF32()) {
- rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadF32Op>(op, convertedType,
- op.getGlobal());
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadF32Op>(
+ op, convertedType, op.getGlobalAttr(), adaptor.getIsImmutableAttr());
} else if (convertedType.isF64()) {
- rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadF64Op>(op, convertedType,
- op.getGlobal());
+ rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadF64Op>(
+ op, convertedType, op.getGlobalAttr(), adaptor.getIsImmutableAttr());
} else {
return rewriter.notifyMatchFailure(op, "unhandled global type");
}
return success();
}
-
-private:
- TypeConverter &typeConverter;
};
-class GlobalLoadIndirectOpConversion
+struct GlobalLoadIndirectOpConversion
: public OpConversionPattern<IREE::Util::GlobalLoadIndirectOp> {
-public:
+ TypeConverter &typeConverter;
GlobalLoadIndirectOpConversion(MLIRContext *context,
TypeConverter &typeConverter)
: OpConversionPattern(context), typeConverter(typeConverter) {}
-
LogicalResult
matchAndRewrite(IREE::Util::GlobalLoadIndirectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -187,35 +173,30 @@
auto convertedType = typeConverter.convertType(operandType);
if (IREE::VM::RefType::isCompatible(operandType)) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectRefOp>(
- op, convertedType, adaptor.getGlobal());
+ op, convertedType, adaptor.getGlobal(), adaptor.getIsImmutableAttr());
} else if (convertedType.isInteger(32)) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectI32Op>(
- op, convertedType, adaptor.getGlobal());
+ op, convertedType, adaptor.getGlobal(), adaptor.getIsImmutableAttr());
} else if (convertedType.isInteger(64)) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectI64Op>(
- op, convertedType, adaptor.getGlobal());
+ op, convertedType, adaptor.getGlobal(), adaptor.getIsImmutableAttr());
} else if (convertedType.isF32()) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectF32Op>(
- op, convertedType, adaptor.getGlobal());
+ op, convertedType, adaptor.getGlobal(), adaptor.getIsImmutableAttr());
} else if (convertedType.isF64()) {
rewriter.replaceOpWithNewOp<IREE::VM::GlobalLoadIndirectF64Op>(
- op, convertedType, adaptor.getGlobal());
+ op, convertedType, adaptor.getGlobal(), adaptor.getIsImmutableAttr());
} else {
return rewriter.notifyMatchFailure(op, "unhandled global type");
}
return success();
}
-
-private:
- TypeConverter &typeConverter;
};
-class GlobalStoreOpConversion
+struct GlobalStoreOpConversion
: public OpConversionPattern<IREE::Util::GlobalStoreOp> {
-public:
GlobalStoreOpConversion(MLIRContext *context, TypeConverter &typeConverter)
: OpConversionPattern(context) {}
-
LogicalResult
matchAndRewrite(IREE::Util::GlobalStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -242,13 +223,11 @@
}
};
-class GlobalStoreIndirectOpConversion
+struct GlobalStoreIndirectOpConversion
: public OpConversionPattern<IREE::Util::GlobalStoreIndirectOp> {
-public:
GlobalStoreIndirectOpConversion(MLIRContext *context,
TypeConverter &typeConverter)
: OpConversionPattern(context) {}
-
LogicalResult
matchAndRewrite(IREE::Util::GlobalStoreIndirectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 0d36737..01d5dac 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -202,8 +202,9 @@
PatternRewriter &rewriter) const override {
if (auto addressOp = dyn_cast_or_null<IREE::Util::GlobalAddressOpInterface>(
op.getGlobal().getDefiningOp())) {
- rewriter.replaceOpWithNewOp<DIRECT>(op, op.getValue().getType(),
- addressOp.getGlobalAttr());
+ rewriter.replaceOpWithNewOp<DIRECT>(
+ op, op.getValue().getType(), addressOp.getGlobalAttr(),
+ addressOp.isGlobalImmutable() ? rewriter.getUnitAttr() : UnitAttr{});
return success();
}
return failure();
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 2520685..7a92598 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -419,6 +419,60 @@
// Globals
//===----------------------------------------------------------------------===//
+IREE::Util::GlobalLoadOpInterface
+GlobalI32Op::createLoadOp(Location loc, OpBuilder &builder) {
+ return builder.create<IREE::VM::GlobalLoadI32Op>(loc, builder.getI32Type(),
+ getGlobalName());
+}
+IREE::Util::GlobalStoreOpInterface
+GlobalI32Op::createStoreOp(Location loc, Value value, OpBuilder &builder) {
+ return builder.create<IREE::VM::GlobalStoreI32Op>(loc, value,
+ getGlobalName());
+}
+
+IREE::Util::GlobalLoadOpInterface
+GlobalI64Op::createLoadOp(Location loc, OpBuilder &builder) {
+ return builder.create<IREE::VM::GlobalLoadI64Op>(loc, builder.getI64Type(),
+ getGlobalName());
+}
+IREE::Util::GlobalStoreOpInterface
+GlobalI64Op::createStoreOp(Location loc, Value value, OpBuilder &builder) {
+ return builder.create<IREE::VM::GlobalStoreI64Op>(loc, value,
+ getGlobalName());
+}
+
+IREE::Util::GlobalLoadOpInterface
+GlobalF32Op::createLoadOp(Location loc, OpBuilder &builder) {
+ return builder.create<IREE::VM::GlobalLoadF32Op>(loc, builder.getF32Type(),
+ getGlobalName());
+}
+IREE::Util::GlobalStoreOpInterface
+GlobalF32Op::createStoreOp(Location loc, Value value, OpBuilder &builder) {
+ return builder.create<IREE::VM::GlobalStoreF32Op>(loc, value,
+ getGlobalName());
+}
+
+IREE::Util::GlobalLoadOpInterface
+GlobalF64Op::createLoadOp(Location loc, OpBuilder &builder) {
+ return builder.create<IREE::VM::GlobalLoadF64Op>(loc, builder.getF64Type(),
+ getGlobalName());
+}
+IREE::Util::GlobalStoreOpInterface
+GlobalF64Op::createStoreOp(Location loc, Value value, OpBuilder &builder) {
+ return builder.create<IREE::VM::GlobalStoreF64Op>(loc, value,
+ getGlobalName());
+}
+
+IREE::Util::GlobalLoadOpInterface
+GlobalRefOp::createLoadOp(Location loc, OpBuilder &builder) {
+ return builder.create<IREE::VM::GlobalLoadRefOp>(loc, getType(),
+ getSymName());
+}
+IREE::Util::GlobalStoreOpInterface
+GlobalRefOp::createStoreOp(Location loc, Value value, OpBuilder &builder) {
+ return builder.create<IREE::VM::GlobalStoreRefOp>(loc, value, getSymName());
+}
+
template <typename T>
static void addMemoryEffectsForGlobal(
Operation *op, mlir::FlatSymbolRefAttr global,
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
index 4c8bdf5..7d3851d 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -379,7 +379,10 @@
IsolatedFromAbove,
HasParent<"IREE::VM::ModuleOp">,
Symbol,
- Util_GlobalOpInterface,
+ DeclareOpInterfaceMethods<Util_GlobalOpInterface, [
+ "createLoadOp",
+ "createStoreOp",
+ ]>
])> {
let arguments = (ins
OptionalAttr<StrAttr>:$sym_visibility,
@@ -545,13 +548,15 @@
}];
let arguments = (ins
- VM_GlobalRefAttr:$global
+ VM_GlobalRefAttr:$global,
+ UnitAttr:$is_immutable
);
let results = (outs
AnyTypeOf<[VM_Ptr, Util_AnyPtrType]>:$result
);
let assemblyFormat = [{
+ (`immutable` $is_immutable^)?
$global attr-dict `:` type($result)
}];
@@ -571,13 +576,15 @@
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
])> {
let arguments = (ins
- VM_GlobalRefAttr:$global
+ VM_GlobalRefAttr:$global,
+ UnitAttr:$is_immutable
);
let results = (outs
type:$value
);
let assemblyFormat = [{
+ (`immutable` $is_immutable^)?
$global attr-dict `:` type($value)
}];
@@ -645,13 +652,15 @@
Util_GlobalLoadIndirectOpInterface,
])> {
let arguments = (ins
- AnyTypeOf<[VM_Ptr, Util_PtrOf<type>]>:$global
+ AnyTypeOf<[VM_Ptr, Util_PtrOf<type>]>:$global,
+ UnitAttr:$is_immutable
);
let results = (outs
type:$value
);
let assemblyFormat = [{
+ (`immutable` $is_immutable^)?
$global attr-dict `:` type($global) `->` type($value)
}];
}
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp
index 9a44d82..069fc37 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/MaterializeConstants.cpp
@@ -88,9 +88,9 @@
valueGlobalOp.setPrivate();
valueGlobalOps.push_back(valueGlobalOp);
for (auto loadOp : loadOps) {
- auto newOp = OpBuilder(loadOp).create<IREE::Util::GlobalLoadOp>(
- loadOp.getLoc(), valueGlobalOp);
- loadOp.replaceAllUsesWith(newOp.getResult());
+ OpBuilder builder(loadOp);
+ auto newOp = valueGlobalOp.createLoadOp(loadOp.getLoc(), builder);
+ loadOp.replaceAllUsesWith(newOp.getLoadedGlobalValue());
loadOp.erase();
}
}
@@ -111,8 +111,9 @@
buffer.getLoc(), sizeof(uint32_t), 32);
for (auto [ordinalGlobalOp, valueGlobalOp] :
llvm::zip_equal(ordinalGlobalOps, valueGlobalOps)) {
- Value loadedOrdinal = setterBuilder.create<IREE::Util::GlobalLoadOp>(
- ordinalGlobalOp.getLoc(), ordinalGlobalOp);
+ Value loadedOrdinal =
+ ordinalGlobalOp.createLoadOp(ordinalGlobalOp.getLoc(), setterBuilder)
+ .getLoadedGlobalValue();
Value bufferOffset = setterBuilder.create<arith::MulIOp>(
loadedOrdinal.getLoc(), loadedOrdinal, elementSizeI32);
Value loadedValue = setterBuilder.create<IREE::Util::BufferLoadOp>(
@@ -121,8 +122,8 @@
setterBuilder.getIndexType(),
bufferOffset),
elementSizeIndex);
- setterBuilder.create<IREE::Util::GlobalStoreOp>(
- valueGlobalOp.getLoc(), loadedValue, valueGlobalOp);
+ valueGlobalOp.createStoreOp(valueGlobalOp.getLoc(), loadedValue,
+ setterBuilder);
}
setterBuilder.create<func::ReturnOp>(setterOp.getLoc());
}
diff --git a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp
index 070e53d..1549a37 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp
@@ -284,30 +284,27 @@
// %0 = util.global.load @foo : tensor<?xf32>
// %d = util.global.load @foo__d0 : index
// %1 = flow.tensor.tie_shape %0 : tensor<?xf32>{%d}
-static void expandGlobalLoadOp(IREE::Util::GlobalLoadOp op,
+static void expandGlobalLoadOp(IREE::Util::GlobalLoadOpInterface op,
ExpandedGlobalMap &globalMap, IndexSet &indexSet,
TensorDimMap &tensorDimMap) {
if (!usesDynamicTensors(op))
return;
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
- auto indexType = builder.getIndexType();
- auto &expandedGlobal = globalMap[op.getGlobal()];
+ auto &expandedGlobal = globalMap[op.getGlobalName()];
ExpandedValue expandedValue;
- expandedValue.tensor = op.getResult();
+ expandedValue.tensor = op.getLoadedGlobalValue();
expandedValue.dynamicDims.reserve(expandedGlobal.dynamicDimOps.size());
for (auto dimOp : expandedGlobal.dynamicDimOps) {
expandedValue.dynamicDims.push_back(
- builder
- .create<IREE::Util::GlobalLoadOp>(op.getLoc(), indexType,
- dimOp.getName())
- .getResult());
+ dimOp.createLoadOp(op.getLoc(), builder).getLoadedGlobalValue());
}
- tensorDimMap[op.getResult()] = expandedValue;
+ tensorDimMap[op.getLoadedGlobalValue()] = expandedValue;
auto tieShapeOp = builder.create<IREE::Flow::TensorTieShapeOp>(
op.getLoc(), expandedValue.tensor.getType(), expandedValue.tensor,
expandedValue.dynamicDims);
- op.getResult().replaceAllUsesExcept(tieShapeOp.getResult(), tieShapeOp);
+ op.getLoadedGlobalValue().replaceAllUsesExcept(tieShapeOp.getResult(),
+ tieShapeOp);
}
// Moves tensor dims from global stores to loads.
@@ -319,7 +316,7 @@
// ->
// util.global.store %0, @foo : tensor<?xf32>
// util.global.store %d, @foo__d0 : index
-static void expandGlobalStoreOp(IREE::Util::GlobalStoreOp op,
+static void expandGlobalStoreOp(IREE::Util::GlobalStoreOpInterface op,
ExpandedGlobalMap &globalMap,
IndexSet &indexSet,
TensorDimMap &tensorDimMap) {
@@ -327,15 +324,14 @@
return;
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
- auto expandedValue = consumeExpandedValue(op.getLoc(), op.getValue(),
- tensorDimMap, indexSet, builder);
- auto &expandedGlobal = globalMap[op.getGlobal()];
- builder.create<IREE::Util::GlobalStoreOp>(op.getLoc(), expandedValue.tensor,
- expandedGlobal.tensorOp.getName());
+ auto expandedValue = consumeExpandedValue(
+ op.getLoc(), op.getStoredGlobalValue(), tensorDimMap, indexSet, builder);
+ auto &expandedGlobal = globalMap[op.getGlobalName()];
+ expandedGlobal.tensorOp.createStoreOp(op.getLoc(), expandedValue.tensor,
+ builder);
for (auto [valueDynamicDims, globalDynamicDimOps] : llvm::zip_equal(
expandedValue.dynamicDims, expandedGlobal.dynamicDimOps)) {
- builder.create<IREE::Util::GlobalStoreOp>(op.getLoc(), valueDynamicDims,
- globalDynamicDimOps.getName());
+ globalDynamicDimOps.createStoreOp(op.getLoc(), valueDynamicDims, builder);
}
op.erase();
}
@@ -550,9 +546,9 @@
// Recursively expands tensors into (tensor, dynamic dims...) in |op|.
static void expandTensorDims(Operation *op, ExpandedGlobalMap &globalMap,
IndexSet &indexSet, TensorDimMap &tensorDimMap) {
- if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOp>(op)) {
+ if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOpInterface>(op)) {
expandGlobalLoadOp(loadOp, globalMap, indexSet, tensorDimMap);
- } else if (auto storeOp = dyn_cast<IREE::Util::GlobalStoreOp>(op)) {
+ } else if (auto storeOp = dyn_cast<IREE::Util::GlobalStoreOpInterface>(op)) {
expandGlobalStoreOp(storeOp, globalMap, indexSet, tensorDimMap);
} else if (auto initializerOp = dyn_cast<IREE::Util::InitializerOp>(op)) {
expandInitializerOp(initializerOp, globalMap, indexSet, tensorDimMap);
diff --git a/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp b/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp
index 9e47513..8f680ad 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/ImportMLProgram.cpp
@@ -174,9 +174,8 @@
auto funcOp = b.create<func::FuncOp>(getterName, funcType);
funcOp.setPublic();
b.setInsertionPointToStart(funcOp.addEntryBlock());
- auto val = b.create<IREE::Util::GlobalLoadOp>(
- newType, SymbolRefAttr::get(globalOp.getSymNameAttr()));
- b.create<func::ReturnOp>(val.getResult());
+ auto val = globalOp.createLoadOp(globalOp.getLoc(), b);
+ b.create<func::ReturnOp>(val.getLoadedGlobalValue());
}
if (!setterName.empty() && isMutable) {
@@ -187,8 +186,7 @@
auto funcOp = b.create<func::FuncOp>(setterName, funcType);
funcOp.setPublic();
b.setInsertionPointToStart(funcOp.addEntryBlock());
- b.create<IREE::Util::GlobalStoreOp>(funcOp.getArgument(0),
- globalOp.getSymNameAttr());
+ globalOp.createStoreOp(globalOp.getLoc(), funcOp.getArgument(0), b);
b.create<func::ReturnOp>();
}
diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/MaterializeExecutables.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/MaterializeExecutables.cpp
index b1a1ae0..8b0ca0c 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/MaterializeExecutables.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/MaterializeExecutables.cpp
@@ -19,7 +19,9 @@
namespace mlir::iree_compiler::IREE::HAL::Loader {
-static void replaceExecutableWithGlobal(IREE::HAL::ExecutableOp executableOp) {
+static void replaceExecutableWithGlobal(
+ IREE::HAL::ExecutableOp executableOp,
+ DenseMap<Attribute, IREE::Util::GlobalOpInterface> &executableGlobalOps) {
OpBuilder moduleBuilder(executableOp);
auto loc = executableOp.getLoc();
@@ -64,8 +66,7 @@
{
auto exitBuilder = OpBuilder::atBlockBegin(exitBlock);
auto executableArg = exitBlock->addArgument(executableType, loc);
- exitBuilder.create<IREE::Util::GlobalStoreOp>(loc, executableArg,
- globalOp.getName());
+ globalOp.createStoreOp(loc, executableArg, exitBuilder);
exitBuilder.create<IREE::Util::ReturnOp>(loc);
}
@@ -120,6 +121,10 @@
ValueRange{executable});
}
+ // Stash for faster lookup when replacing using.
+ executableGlobalOps[FlatSymbolRefAttr::get(globalOp.getNameAttr())] =
+ globalOp;
+
// Op goes away to get replaced with a global.
executableOp.erase();
}
@@ -138,18 +143,20 @@
mlir::ModuleOp moduleOp = getOperation();
// Walk executables and convert each one to a global.
+ DenseMap<Attribute, IREE::Util::GlobalOpInterface> executableGlobalOps;
for (auto executableOp : llvm::make_early_inc_range(
moduleOp.getOps<IREE::HAL::ExecutableOp>())) {
- replaceExecutableWithGlobal(executableOp);
+ replaceExecutableWithGlobal(executableOp, executableGlobalOps);
}
// Find lookup ops referencing an executable and swap it to a global load.
for (auto funcOp : llvm::make_early_inc_range(
moduleOp.getOps<mlir::FunctionOpInterface>())) {
funcOp.walk([&](IREE::HAL::Loader::ExecutableLookupOp lookupOp) {
- Value executable = OpBuilder(lookupOp).create<IREE::Util::GlobalLoadOp>(
- lookupOp.getLoc(), lookupOp.getResult().getType(),
- lookupOp.getExecutableAttr());
+ OpBuilder builder(lookupOp);
+ auto globalOp = executableGlobalOps[lookupOp.getExecutableAttr()];
+ Value executable = globalOp.createLoadOp(lookupOp.getLoc(), builder)
+ .getLoadedGlobalValue();
lookupOp.replaceAllUsesWith(executable);
lookupOp.erase();
});
diff --git a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
index 4887829..38ede9a 100644
--- a/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
+++ b/compiler/src/iree/compiler/Modules/IO/Parameters/Transforms/ExportParameters.cpp
@@ -72,8 +72,8 @@
arith::ConstantOp originalConstant = it.first;
Util::GlobalOp globalOp = it.second;
rewriter.setInsertionPointAfterValue(originalConstant);
- Value load =
- rewriter.create<Util::GlobalLoadOp>(globalOp->getLoc(), globalOp);
+ Value load = globalOp.createLoadOp(globalOp.getLoc(), rewriter)
+ .getLoadedGlobalValue();
rewriter.replaceOp(originalConstant, load);
}
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index 66f60ba..9d7bab7 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -35,8 +35,9 @@
bool LayoutAttr::isValidLayout(ArrayRef<int64_t> shape) const {
for (auto perDimLayout : llvm::enumerate(getLayouts())) {
ArrayRef<int64_t> layoutShape = perDimLayout.value().getShapes();
- int64_t computedShape = std::reduce(layoutShape.begin(), layoutShape.end(),
- 1, std::multiplies<int64_t>());
+ int64_t computedShape =
+ std::reduce(layoutShape.begin(), layoutShape.end(),
+ static_cast<int64_t>(1), std::multiplies<int64_t>());
int64_t expectedShape = shape[perDimLayout.index()];
if (computedShape != expectedShape) {
return false;