Adding hal.variable.* ops to match the flow.variable.* ops.
Noticed that I got the order wrong on flow.variable.store and corrected that so that the two match.
PiperOrigin-RevId: 283454948
diff --git a/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp b/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
index c4bdadf..0ae2056 100644
--- a/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
+++ b/integrations/tensorflow/compiler/TFSavedModelAdoptExports.cpp
@@ -87,8 +87,8 @@
dyn_cast<TF::AssignVariableOp>(operand.getOwner())) {
OpBuilder(assign_variable)
.create<IREE::Flow::VariableStoreOp>(assign_variable.getLoc(),
- flow_sym_ref,
- assign_variable.value());
+ assign_variable.value(),
+ flow_sym_ref);
assign_variable.erase();
continue;
}
diff --git a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
index f017a43..181e4a7 100644
--- a/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
+++ b/integrations/tensorflow/compiler/test/saved_model_adopt_exports.py
@@ -75,7 +75,7 @@
# CHECK: flow.variable @v mutable dense<0.000000e+00> : tensor<f32>
# CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.index_path = [0]})
# CHECK: attributes{{.*}}iree.module.export
-# CHECK: flow.variable.store @v, %arg0 : tensor<f32>
+# CHECK: flow.variable.store %arg0, @v : tensor<f32>
# CHECK: FINISH_TEST
class T0002b_SimpleVarWrite(tf.Module):
@@ -109,8 +109,8 @@
# CHECK: attributes{{.*}}iree.module.export
# CHECK: [[CONST_2xf32:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32>
# CHECK: [[CONST_3xf32:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
-# CHECK: flow.variable.store @v, [[CONST_2xf32]] : tensor<2xf32>
-# CHECK: flow.variable.store @v, [[CONST_3xf32]] : tensor<3xf32>
+# CHECK: flow.variable.store [[CONST_2xf32]], @v : tensor<2xf32>
+# CHECK: flow.variable.store [[CONST_3xf32]], @v : tensor<3xf32>
# CHECK: FINISH_TEST
class T0002d_VarCompatibleShapeChange(tf.Module):
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index f7dc689..10e51cd 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -222,12 +222,12 @@
static ParseResult parseVariableStoreOp(OpAsmParser &parser,
OperationState *result) {
- FlatSymbolRefAttr variableAttr;
OpAsmParser::OperandType value;
+ FlatSymbolRefAttr variableAttr;
Type valueType;
- if (failed(parser.parseAttribute(variableAttr, "variable",
+ if (failed(parser.parseOperand(value)) || failed(parser.parseComma()) ||
+ failed(parser.parseAttribute(variableAttr, "variable",
result->attributes)) ||
- failed(parser.parseComma()) || failed(parser.parseOperand(value)) ||
failed(parser.parseOptionalAttrDict(result->attributes)) ||
failed(parser.parseColonType(valueType)) ||
failed(parser.resolveOperand(value, valueType, result->operands))) {
@@ -238,9 +238,9 @@
static void printVariableStoreOp(OpAsmPrinter &p, VariableStoreOp &op) {
p << op.getOperationName() << ' ';
- p.printSymbolName(op.variable());
- p << ", ";
p.printOperand(op.value());
+ p << ", ";
+ p.printSymbolName(op.variable());
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"});
p << " : ";
p.printType(op.value()->getType());
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index cf38705..fee0540 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -86,8 +86,8 @@
}];
let arguments = (ins
- FLOW_VariableRefAttr:$variable,
- AnyRankedTensor:$value
+ AnyRankedTensor:$value,
+ FLOW_VariableRefAttr:$variable
);
let verifier = [{ return verifyVariableStoreOp(*this); }];
diff --git a/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir
index 6c0766b..102c6d7 100644
--- a/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/variable_folding.mlir
@@ -40,7 +40,7 @@
func @nop_load_store() {
// CHECK-NEXT: return
%0 = flow.variable.load @v_nop : tensor<4xi32>
- flow.variable.store @v_nop, %0 : tensor<4xi32>
+ flow.variable.store %0, @v_nop : tensor<4xi32>
return
}
diff --git a/iree/compiler/Dialect/Flow/IR/test/variable_ops.mlir b/iree/compiler/Dialect/Flow/IR/test/variable_ops.mlir
index a8f1182..22ab9f5 100644
--- a/iree/compiler/Dialect/Flow/IR/test/variable_ops.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/variable_ops.mlir
@@ -37,7 +37,7 @@
flow.variable @v_loaded : tensor<4xi32>
// CHECK-LABEL: @loaded
func @loaded() {
- // CHECK-NEXT: %0 = flow.variable.load @v_loaded : tensor<4xi32>
+ // CHECK-NEXT: = flow.variable.load @v_loaded : tensor<4xi32>
%0 = flow.variable.load @v_loaded : tensor<4xi32>
return
}
@@ -47,9 +47,9 @@
flow.variable @v_stored mutable : tensor<4xi32>
// CHECK-LABEL: @stored
func @stored() {
- // CHECK-NEXT: = constant
+ // CHECK-NEXT: [[VAL:%.+]] = constant
%cst = constant dense<5> : tensor<4xi32>
- // CHECK-NEXT: flow.variable.store @v_stored, %cst : tensor<4xi32>
- flow.variable.store @v_stored, %cst : tensor<4xi32>
+ // CHECK-NEXT: flow.variable.store [[VAL]], @v_stored : tensor<4xi32>
+ flow.variable.store %cst, @v_stored : tensor<4xi32>
return
}
diff --git a/iree/compiler/Dialect/HAL/IR/HALBase.td b/iree/compiler/Dialect/HAL/IR/HALBase.td
index a6f3182..cb05b18 100644
--- a/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -338,6 +338,9 @@
def HAL_PrimitiveType : AnyTypeOf<[Index, AnyInteger, AnyFloat]>;
+def HAL_VariableRefAttr : AliasedSymbolRefAttr;
+def HAL_VariableType : AnyTypeOf<[HAL_PrimitiveType, AnyVector, AnyRefPtr]>;
+
def HAL_Dim : I<32>;
def HAL_Dims : Variadic<HAL_Dim>;
def HAL_Shape : Variadic<HAL_Dim>;
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index ff89277..f002e65 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -15,6 +15,7 @@
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/StringExtras.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
@@ -30,6 +31,100 @@
namespace HAL {
//===----------------------------------------------------------------------===//
+// Variables
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Converts variable initializer functions that evaluate to a constant to a
+/// specified initial value.
+struct InlineConstVariableOpInitializer : public OpRewritePattern<VariableOp> {
+ using OpRewritePattern<VariableOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(VariableOp op,
+ PatternRewriter &rewriter) const override {
+ if (!op.initializer()) return matchFailure();
+ auto *symbolOp =
+ SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue());
+ auto initializer = cast<FuncOp>(symbolOp);
+ if (initializer.getBlocks().size() == 1 &&
+ initializer.getBlocks().front().getOperations().size() == 2 &&
+ isa<mlir::ReturnOp>(
+ initializer.getBlocks().front().getOperations().back())) {
+ auto &primaryOp = initializer.getBlocks().front().getOperations().front();
+ Attribute constResult;
+ if (matchPattern(primaryOp.getResult(0), m_Constant(&constResult))) {
+ rewriter.replaceOpWithNewOp<VariableOp>(
+ op, op.sym_name(), op.is_mutable(), op.type(), constResult);
+ return matchSuccess();
+ }
+ }
+ return matchFailure();
+ }
+};
+
+} // namespace
+
+void VariableOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<InlineConstVariableOpInitializer>(context);
+}
+
+namespace {
+
+/// Erases hal.variable.load ops whose values are unused.
+/// We have to do this manually as the load op cannot be marked pure and have it
+/// done automatically.
+struct EraseUnusedVariableLoadOp : public OpRewritePattern<VariableLoadOp> {
+ using OpRewritePattern<VariableLoadOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(VariableLoadOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.result()->use_empty()) {
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+};
+
+} // namespace
+
+void VariableLoadOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<EraseUnusedVariableLoadOp>(context);
+}
+
+namespace {
+
+/// Erases hal.variable.store ops that are no-ops.
+/// This can happen if there was a variable load, some DCE'd usage, and a
+/// store back to the same variable: we want to be able to elide the entire load
+/// and store.
+struct EraseUnusedVariableStoreOp : public OpRewritePattern<VariableStoreOp> {
+ using OpRewritePattern<VariableStoreOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(VariableStoreOp op,
+ PatternRewriter &rewriter) const override {
+ if (auto loadOp =
+ dyn_cast_or_null<VariableLoadOp>(op.value()->getDefiningOp())) {
+ if (loadOp.variable() == op.variable()) {
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+ }
+ return matchFailure();
+ }
+};
+
+} // namespace
+
+void VariableStoreOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<EraseUnusedVariableStoreOp>(context);
+}
+
+//===----------------------------------------------------------------------===//
// iree::hal::Allocator
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 6df4911..26be4f2 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/TypeUtilities.h"
namespace mlir {
namespace iree_compiler {
@@ -338,6 +339,259 @@
}
//===----------------------------------------------------------------------===//
+// hal.variable
+//===----------------------------------------------------------------------===//
+
+// Returns true if the given |accessType| is compatible with the |variableType|.
+// For example, this will return true if the variable type is a tensor<?xf32>
+// and the access is tensor<4xf32>.
+static bool isVariableTypeCompatible(Type variableType, Type accessType) {
+ return succeeded(mlir::verifyCompatibleShape(variableType, accessType));
+}
+
+static ParseResult parseVariableOp(OpAsmParser &parser,
+ OperationState *result) {
+ StringAttr nameAttr;
+ if (failed(parser.parseSymbolName(nameAttr,
+ mlir::SymbolTable::getSymbolAttrName(),
+ result->attributes))) {
+ return failure();
+ }
+
+ if (succeeded(parser.parseOptionalKeyword("mutable"))) {
+ result->addAttribute("is_mutable", UnitAttr::get(result->getContext()));
+ }
+
+ if (succeeded(parser.parseOptionalKeyword("init"))) {
+ FlatSymbolRefAttr initializerAttr;
+ if (failed(parser.parseLParen()) ||
+ failed(parser.parseAttribute(initializerAttr, "initializer",
+ result->attributes)) ||
+ failed(parser.parseRParen())) {
+ return failure();
+ }
+ }
+
+ if (failed(parser.parseOptionalColon())) {
+ Attribute initialValueAttr;
+ if (failed(parser.parseAttribute(initialValueAttr, "initial_value",
+ result->attributes))) {
+ return failure();
+ }
+ result->addAttribute("type", TypeAttr::get(initialValueAttr.getType()));
+ } else {
+ Type type;
+ if (failed(parser.parseType(type))) {
+ return failure();
+ }
+ result->addAttribute("type", TypeAttr::get(type));
+ }
+
+ return success();
+}
+
+static void printVariableOp(OpAsmPrinter &p, VariableOp op) {
+ p << op.getOperationName() << ' ';
+ p.printSymbolName(op.sym_name());
+ if (op.is_mutable()) {
+ p << " mutable";
+ }
+ if (op.initializer().hasValue()) {
+ p << " init(";
+ p.printSymbolName(op.initializer().getValue());
+ p << ')';
+ }
+ if (op.initial_value().hasValue()) {
+ p << ' ';
+ p.printAttribute(op.initial_value().getValue());
+ } else {
+ p << " : ";
+ p.printType(op.type());
+ }
+}
+
+static LogicalResult verifyVariableOp(VariableOp op) {
+ if (op.initializer().hasValue() && op.initial_value().hasValue()) {
+ return op.emitOpError()
+ << "variables can have either an initializer or an initial value";
+ } else if (op.initializer().hasValue()) {
+ // Ensure initializer returns the same type as the variable.
+ auto *symbolOp =
+ SymbolTable::lookupNearestSymbolFrom(op, op.initializer().getValue());
+ if (!symbolOp) {
+ return op.emitOpError() << "initializer function "
+ << op.initializer().getValue() << " not found";
+ }
+ auto initializerOp = dyn_cast<FuncOp>(symbolOp);
+ if (initializerOp.getNumArguments() != 0 ||
+ initializerOp.getNumResults() != 1 ||
+ initializerOp.getType().getResult(0) != op.type()) {
+ return op.emitOpError()
+ << "initializer type mismatch; variable " << op.sym_name()
+ << " is " << op.type() << " but initializer function "
+ << initializerOp.getName() << " is " << initializerOp.getType();
+ }
+ } else if (op.initial_value().hasValue()) {
+ // Ensure the value is something we can store in the variable
+ if (!isVariableTypeCompatible(op.type(), op.initial_value()->getType())) {
+ return op.emitOpError()
+ << "initial value type mismatch; variable " << op.sym_name()
+ << " is " << op.type() << " but initial value provided is "
+ << op.initial_value()->getType();
+ }
+ }
+ return success();
+}
+
+void VariableOp::build(Builder *builder, OperationState &state, StringRef name,
+ bool isMutable, Type type,
+ Optional<StringRef> initializer,
+ Optional<Attribute> initialValue,
+ ArrayRef<NamedAttribute> attrs) {
+ state.addAttribute(SymbolTable::getSymbolAttrName(),
+ builder->getStringAttr(name));
+ if (isMutable) {
+ state.addAttribute("is_mutable", builder->getUnitAttr());
+ }
+ if (initializer.hasValue()) {
+ state.addAttribute("initializer",
+ builder->getSymbolRefAttr(initializer.getValue()));
+ } else if (initialValue.hasValue()) {
+ state.addAttribute("initial_value", initialValue.getValue());
+ }
+ state.addAttribute("type", TypeAttr::get(type));
+ state.attributes.append(attrs.begin(), attrs.end());
+}
+
+void VariableOp::build(Builder *builder, OperationState &state, StringRef name,
+ bool isMutable, mlir::FuncOp initializer,
+ ArrayRef<NamedAttribute> attrs) {
+ state.addAttribute(SymbolTable::getSymbolAttrName(),
+ builder->getStringAttr(name));
+ if (isMutable) {
+ state.addAttribute("is_mutable", builder->getUnitAttr());
+ }
+ state.addAttribute("initializer", builder->getSymbolRefAttr(initializer));
+ state.addAttribute("type", TypeAttr::get(initializer.getType().getResult(0)));
+ state.attributes.append(attrs.begin(), attrs.end());
+}
+
+void VariableOp::build(Builder *builder, OperationState &result, StringRef name,
+ bool isMutable, Type type, Attribute initialValue,
+ ArrayRef<NamedAttribute> attrs) {
+ result.addAttribute(SymbolTable::getSymbolAttrName(),
+ builder->getStringAttr(name));
+ if (isMutable) {
+ result.addAttribute("is_mutable", builder->getUnitAttr());
+ }
+ result.addAttribute("initial_value", initialValue);
+ result.addAttribute("type", TypeAttr::get(type));
+ result.attributes.append(attrs.begin(), attrs.end());
+}
+
+void VariableOp::build(Builder *builder, OperationState &result, StringRef name,
+ bool isMutable, Type type,
+ ArrayRef<NamedAttribute> attrs) {
+ result.addAttribute(SymbolTable::getSymbolAttrName(),
+ builder->getStringAttr(name));
+ if (isMutable) {
+ result.addAttribute("is_mutable", builder->getUnitAttr());
+ }
+ result.addAttribute("type", TypeAttr::get(type));
+ result.attributes.append(attrs.begin(), attrs.end());
+}
+
+//===----------------------------------------------------------------------===//
+// hal.variable.load
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseVariableLoadOp(OpAsmParser &parser,
+ OperationState *result) {
+ FlatSymbolRefAttr variableAttr;
+ Type valueType;
+ if (failed(parser.parseAttribute(variableAttr, "variable",
+ result->attributes)) ||
+ failed(parser.parseOptionalAttrDict(result->attributes)) ||
+ failed(parser.parseColonType(valueType))) {
+ return failure();
+ }
+ result->addTypes({valueType});
+ return success();
+}
+
+static void printVariableLoadOp(OpAsmPrinter &p, VariableLoadOp &op) {
+ p << op.getOperationName() << ' ';
+ p.printSymbolName(op.variable());
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"});
+ p << " : ";
+ p.printType(op.result()->getType());
+}
+
+static LogicalResult verifyVariableLoadOp(VariableLoadOp &op) {
+ auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable());
+ if (!symbolOp) {
+ return op.emitOpError() << "undefined variable: " << op.variable();
+ }
+ auto variableOp = dyn_cast<VariableOp>(symbolOp);
+ auto loadType = op.result()->getType();
+ if (!isVariableTypeCompatible(variableOp.type(), loadType)) {
+ return op.emitOpError()
+ << "variable type mismatch; variable " << op.variable() << " is "
+ << variableOp.type() << " but load is " << loadType;
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// hal.variable.store
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseVariableStoreOp(OpAsmParser &parser,
+ OperationState *result) {
+ OpAsmParser::OperandType value;
+ FlatSymbolRefAttr variableAttr;
+ Type valueType;
+ if (failed(parser.parseOperand(value)) || failed(parser.parseComma()) ||
+ failed(parser.parseAttribute(variableAttr, "variable",
+ result->attributes)) ||
+ failed(parser.parseOptionalAttrDict(result->attributes)) ||
+ failed(parser.parseColonType(valueType)) ||
+ failed(parser.resolveOperand(value, valueType, result->operands))) {
+ return failure();
+ }
+ return success();
+}
+
+static void printVariableStoreOp(OpAsmPrinter &p, VariableStoreOp &op) {
+ p << op.getOperationName() << ' ';
+ p.printOperand(op.value());
+ p << ", ";
+ p.printSymbolName(op.variable());
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"variable"});
+ p << " : ";
+ p.printType(op.value()->getType());
+}
+
+static LogicalResult verifyVariableStoreOp(VariableStoreOp &op) {
+ auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(op, op.variable());
+ if (!symbolOp) {
+ return op.emitOpError() << "undefined variable: " << op.variable();
+ }
+ auto variableOp = dyn_cast<VariableOp>(symbolOp);
+ auto storeType = op.value()->getType();
+ if (!isVariableTypeCompatible(variableOp.type(), storeType)) {
+ return op.emitOpError()
+ << "variable type mismatch; variable " << op.variable() << " is "
+ << variableOp.type() << " but store is " << storeType;
+ }
+ if (!variableOp.is_mutable()) {
+ return op.emitOpError() << "variable " << op.variable()
+ << " is not mutable and cannot be stored to";
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// hal.allocator.compute_size
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.h b/iree/compiler/Dialect/HAL/IR/HALOps.h
index c26602a..a669c5b 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.h
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.h
@@ -22,6 +22,7 @@
#include "iree/compiler/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 2379e7e..d6da64d 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -154,6 +154,91 @@
}
//===----------------------------------------------------------------------===//
+// Global variables
+//===----------------------------------------------------------------------===//
+
+def HAL_VariableOp : HAL_Op<"variable", [
+ Symbol,
+ ]> {
+ let summary = [{stateful variable declaration}];
+ let description = [{
+ Declares a global variable that maintains its value across invocations.
+ The value is tied to the execution context of the module and different
+ contexts will have different variable storage.
+ }];
+
+ let arguments = (ins
+ StrAttr:$sym_name,
+ TypeAttr:$type,
+ UnitAttr:$is_mutable,
+ // TODO(benvanik): verify matches $type.
+ OptionalAttr<FlatSymbolRefAttr>:$initializer,
+ // TODO(benvanik): verify matches $type.
+ OptionalAttr<AnyAttr>:$initial_value
+ );
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<[{
+ Builder *builder, OperationState &result, StringRef name, bool isMutable,
+ Type type,
+ Optional<StringRef> initializer, Optional<Attribute> initialValue,
+ ArrayRef<NamedAttribute> attrs = {}
+ }]>,
+ OpBuilder<[{
+ Builder *builder, OperationState &result, StringRef name, bool isMutable,
+ mlir::FuncOp initializer, ArrayRef<NamedAttribute> attrs = {}
+ }]>,
+ OpBuilder<[{
+ Builder *builder, OperationState &result, StringRef name, bool isMutable,
+ Type type, Attribute initialValue, ArrayRef<NamedAttribute> attrs = {}
+ }]>,
+ OpBuilder<[{
+ Builder *builder, OperationState &result, StringRef name, bool isMutable,
+ Type type, ArrayRef<NamedAttribute> attrs = {}
+ }]>,
+ ];
+
+ let verifier = [{ return verifyVariableOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def HAL_VariableLoadOp : HAL_Op<"variable.load"> {
+ let summary = [{loads a value from a global variable}];
+ let description = [{
+ Returns a copy of the variable value.
+ }];
+
+ let arguments = (ins
+ HAL_VariableRefAttr:$variable
+ );
+ let results = (outs
+ HAL_VariableType:$result
+ );
+
+ let verifier = [{ return verifyVariableLoadOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+def HAL_VariableStoreOp : HAL_Op<"variable.store"> {
+ let summary = [{stores a value into a global variable}];
+ let description = [{
+ Stores a copy of the value into a variable.
+ }];
+
+ let arguments = (ins
+ HAL_VariableType:$value,
+ HAL_VariableRefAttr:$variable
+ );
+
+ let verifier = [{ return verifyVariableStoreOp(*this); }];
+
+ let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
// iree::hal::Allocator
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/test/variable_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/variable_folding.mlir
new file mode 100644
index 0000000..31f56de
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/IR/test/variable_folding.mlir
@@ -0,0 +1,46 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests folding and canonicalization of variable ops.
+
+// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK: hal.variable @v_initialized 4 : i32
+hal.variable @v_initialized init(@initializer) : i32
+func @initializer() -> i32 {
+ %0 = constant 4 : i32
+ return %0 : i32
+}
+
+// -----
+
+hal.variable @v_unused : !ireex.ref<!hal.buffer>
+// CHECK-LABEL: @unused_load
+func @unused_load() {
+ // CHECK-NEXT: return
+ %0 = hal.variable.load @v_unused : !ireex.ref<!hal.buffer>
+ return
+}
+
+// -----
+
+hal.variable @v_nop mutable : !ireex.ref<!hal.buffer>
+// CHECK-LABEL: @nop_load_store
+func @nop_load_store() {
+ // CHECK-NEXT: return
+ %0 = hal.variable.load @v_nop : !ireex.ref<!hal.buffer>
+ hal.variable.store %0, @v_nop : !ireex.ref<!hal.buffer>
+ return
+}
+
diff --git a/iree/compiler/Dialect/HAL/IR/test/variable_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/variable_ops.mlir
new file mode 100644
index 0000000..3bf32db
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/IR/test/variable_ops.mlir
@@ -0,0 +1,55 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests printing and parsing of variable ops.
+
+// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
+
+// CHECK: hal.variable @v_immutable : tensor<i32>
+hal.variable @v_immutable : tensor<i32>
+// CHECK: hal.variable @v_mutable mutable : tensor<i32>
+hal.variable @v_mutable mutable : tensor<i32>
+
+// -----
+
+// CHECK: hal.variable @v_initialized_const 4 : i32
+hal.variable @v_initialized_const 4 : i32
+
+// -----
+
+// CHECK: hal.variable @v_initialized init(@initializer) : !ireex.ref<!hal.buffer>
+hal.variable @v_initialized init(@initializer) : !ireex.ref<!hal.buffer>
+func @initializer() -> !ireex.ref<!hal.buffer>
+
+// -----
+
+hal.variable @v_loaded : !ireex.ref<!hal.buffer>
+// CHECK-LABEL: @loaded
+func @loaded() {
+ // CHECK-NEXT: = hal.variable.load @v_loaded : !ireex.ref<!hal.buffer>
+ %0 = hal.variable.load @v_loaded : !ireex.ref<!hal.buffer>
+ return
+}
+
+// -----
+
+hal.variable @v_stored mutable : !ireex.ref<!hal.buffer>
+// CHECK-LABEL: @stored
+func @stored() {
+ // CHECK-NEXT: [[BUF:%.+]] = "test_hal.buffer"
+ %0 = "test_hal.buffer"() : () -> !ireex.ref<!hal.buffer>
+ // CHECK-NEXT: hal.variable.store [[BUF]], @v_stored : !ireex.ref<!hal.buffer>
+ hal.variable.store %0, @v_stored : !ireex.ref<!hal.buffer>
+ return
+}