Allowing ranked_shapes in variables and adding support for expanding dims.
Cleanup as part of #4675 should move this pass to a shared place so it
can be run both in flow and hal.
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 7494fd1..959318a 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -92,7 +92,7 @@
FLOW_VariableRefAttr:$variable
);
let results = (outs
- AnyRankedTensor:$result
+ AnyType:$result
);
let assemblyFormat = "$variable attr-dict `:` type($result)";
@@ -112,7 +112,7 @@
FLOW_VariablePtr:$variable
);
let results = (outs
- AnyRankedTensor:$result
+ AnyType:$result
);
let assemblyFormat = "$variable attr-dict `:` type($variable) `->` type($result)";
@@ -129,7 +129,7 @@
}];
let arguments = (ins
- AnyRankedTensor:$value,
+ AnyType:$value,
FLOW_VariableRefAttr:$variable
);
@@ -147,7 +147,7 @@
}];
let arguments = (ins
- AnyRankedTensor:$value,
+ AnyType:$value,
FLOW_VariablePtr:$variable
);
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index c32aba1..f0cf29e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -27,6 +27,7 @@
"DispatchConfig.cpp",
"DispatchLinalgOnTensors.cpp",
"DispatchabilityAnalysis.cpp",
+ "ExpandVariableDynamicDims.cpp",
"FlattenTuplesInCFG.cpp",
"FoldCompatibleDispatchRegions.cpp",
"FormStreams.cpp",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 4c2f4ac..8f07e4f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -28,6 +28,7 @@
"DispatchConfig.cpp"
"DispatchLinalgOnTensors.cpp"
"DispatchabilityAnalysis.cpp"
+ "ExpandVariableDynamicDims.cpp"
"FlattenTuplesInCFG.cpp"
"FoldCompatibleDispatchRegions.cpp"
"FormStreams.cpp"
diff --git a/iree/compiler/Dialect/Flow/Transforms/ExpandVariableDynamicDims.cpp b/iree/compiler/Dialect/Flow/Transforms/ExpandVariableDynamicDims.cpp
new file mode 100644
index 0000000..d588938
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/ExpandVariableDynamicDims.cpp
@@ -0,0 +1,151 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+class ExpandVariableDynamicDimsPass
+ : public PassWrapper<ExpandVariableDynamicDimsPass,
+ OperationPass<ModuleOp>> {
+ public:
+ ExpandVariableDynamicDimsPass() = default;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Flow::FlowDialect>();
+ registry.insert<ShapeDialect>();
+ }
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+
+ // Gathers all of the flow.variables containing shapes.
+ SmallVector<VariableOp, 4> shapeVarOps;
+ moduleOp.walk([&](VariableOp op) {
+ if (op.type().isa<Shape::RankedShapeType>()) {
+ shapeVarOps.push_back(op);
+ }
+ });
+
+ // Split each variable into one variable per dimension.
+ for (auto shapeVarOp : shapeVarOps) {
+ expandShapeVariable(moduleOp, shapeVarOp);
+ }
+ }
+
+ private:
+ // Expands a flow.variable representing a shape with one variable per dim.
+ // Uses of the variable will be replaced with the per-dim ones.
+ void expandShapeVariable(ModuleOp moduleOp, VariableOp shapeVarOp) {
+ // Create one flow.variable per dimension (static or dynamic).
+ OpBuilder moduleBuilder(shapeVarOp);
+ auto shapeType = shapeVarOp.type().cast<Shape::RankedShapeType>();
+ SmallVector<VariableOp, 4> dimVarOps;
+ for (int i = 0; i < shapeType.getRank(); ++i) {
+ Attribute initialDimValue;
+ if (shapeType.isDimDynamic(i)) {
+ // Right now we choose zero for initial dynamic dim values but this
+ // needs to agree with bindings that may have expectations on query.
+ // 0 is at least easier to gracefully bail on when values are never
+ // overridden.
+ initialDimValue = moduleBuilder.getIndexAttr(0);
+ } else {
+ initialDimValue = moduleBuilder.getIndexAttr(shapeType.getStaticDim(i));
+ }
+ auto dimVarOp = moduleBuilder.create<VariableOp>(
+ shapeVarOp.getLoc(),
+ (shapeVarOp.getName() + "_d" + std::to_string(i)).str(),
+ /*isMutable=*/shapeType.isDimDynamic(i), moduleBuilder.getIndexType(),
+ initialDimValue);
+ dimVarOp.setPrivate();
+ dimVarOps.push_back(dimVarOp);
+ }
+
+ // Replace all uses of the single variable with the split ones.
+ replaceShapeVariableUses(moduleOp, shapeType, shapeVarOp, dimVarOps);
+
+ // Erase the original variable.
+ shapeVarOp.erase();
+ }
+
+ // Replaces uses of |shapeVarOp| in |moduleOp| with the expanded |dimVarOps|.
+ void replaceShapeVariableUses(ModuleOp moduleOp,
+ Shape::RankedShapeType shapeType,
+ VariableOp shapeVarOp,
+ ArrayRef<VariableOp> dimVarOps) {
+ auto allUses = SymbolTable::getSymbolUses(shapeVarOp, moduleOp)
+ .getValueOr(SymbolTable::UseRange({}));
+ for (auto use : allUses) {
+ if (auto loadOp = dyn_cast<VariableLoadOp>(use.getUser())) {
+ OpBuilder builder(loadOp);
+ SmallVector<Value, 4> dynamicDimValues;
+ for (int i = 0; i < shapeType.getRank(); ++i) {
+ if (!shapeType.isDimDynamic(i)) continue;
+ VariableOp dimVarOp = dimVarOps[i];
+ dynamicDimValues.push_back(builder.create<VariableLoadOp>(
+ loadOp.getLoc(), builder.getIndexType(), dimVarOp.getName()));
+ }
+ auto shapeValue = builder.create<Shape::MakeRankedShapeOp>(
+ loadOp.getLoc(), shapeType, dynamicDimValues);
+ loadOp->replaceAllUsesWith(shapeValue);
+ loadOp.erase();
+ } else if (auto storeOp = dyn_cast<VariableStoreOp>(use.getUser())) {
+ OpBuilder builder(storeOp);
+ auto shapeValue = storeOp.value();
+ for (int i = 0; i < shapeType.getRank(); ++i) {
+ if (!shapeType.isDimDynamic(i)) continue;
+ VariableOp dimVarOp = dimVarOps[i];
+ auto dynamicDimValue = builder.createOrFold<Shape::RankedDimOp>(
+ storeOp.getLoc(), shapeValue, i);
+ builder.create<VariableStoreOp>(storeOp.getLoc(), dynamicDimValue,
+ dimVarOp.getName());
+ }
+ storeOp.erase();
+ } else {
+ // TODO(benvanik): support indirection/addressing - should be fairly
+ // easy to do by splitting the address ops to each dim.
+ use.getUser()->emitError()
+ << "variable action on shape is not yet supported";
+ signalPassFailure();
+ return;
+ }
+ }
+ }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createExpandVariableDynamicDimsPass() {
+ return std::make_unique<ExpandVariableDynamicDimsPass>();
+}
+
+static PassRegistration<ExpandVariableDynamicDimsPass> pass(
+ "iree-flow-expand-variable-dynamic-dims",
+ "Expands !shapex.ranked_shape dynamic dimensions stored in variables.",
+ [] { return std::make_unique<ExpandVariableDynamicDimsPass>(); });
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index fffc559..8ecaa6a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -122,6 +122,11 @@
passManager.addNestedPass<FuncOp>(
IREE::Flow::createMaterializeExportedReflection());
+ // Replaces variables with !shapex.ranked_shape types with individual
+ // variables for each dimension. This allows for constant dimensions to be
+ // DCE'd in following passes.
+ passManager.addPass(IREE::Flow::createExpandVariableDynamicDimsPass());
+
// Materialize dynamic shapes in the IR, also expanding function signatures
// such that:
// - Dynamic ranked tensors: (tensor<?x?xf32>) expands to
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index a1bc2a0..31ea5a0 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -96,6 +96,9 @@
// expected.
std::unique_ptr<OperationPass<FuncOp>> createMergeExportedReflection();
+// Expands dynamic !shapex.ranked_shape dimensions in variables.
+std::unique_ptr<OperationPass<ModuleOp>> createExpandVariableDynamicDimsPass();
+
//===----------------------------------------------------------------------===//
// Dispatches (flow.dispatch.region)
//===----------------------------------------------------------------------===//
@@ -193,6 +196,7 @@
createPostPartitioningConversionPass();
createMaterializeExportedReflection();
createMergeExportedReflection();
+ createExpandVariableDynamicDimsPass();
createDispatchabilityAnalysisPass();
createIdentifyDispatchRegionsPass();
createIdentifyDispatchRegions2Pass();
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/expand_variable_dynamic_dims.mlir b/iree/compiler/Dialect/Flow/Transforms/test/expand_variable_dynamic_dims.mlir
new file mode 100644
index 0000000..2588b9a
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/expand_variable_dynamic_dims.mlir
@@ -0,0 +1,73 @@
+// RUN: iree-opt -split-input-file -iree-flow-expand-variable-dynamic-dims %s | IreeFileCheck %s
+
+// CHECK-NOT: flow.variable @static_var mutable
+// CHECK: flow.variable @static_var_d0 1 : index
+// CHECK: flow.variable @static_var_d1 2 : index
+// CHECK: flow.variable @static_var_d2 3 : index
+// CHECK: flow.variable @static_var_d3 4 : index
+flow.variable @static_var mutable : !shapex.ranked_shape<[1,2,3,4]>
+// CHECK-LABEL: func @static_loads
+func @static_loads() -> (index, index, index, index) {
+ // CHECK-NOT: flow.variable.load
+ // CHECK-NEXT: %[[SHAPE:.+]] = shapex.make_ranked_shape
+ %0 = flow.variable.load @static_var : !shapex.ranked_shape<[1,2,3,4]>
+ // CHECK-DAG: %[[D0:.+]] = shapex.ranked_dim %[[SHAPE]][0]
+ %1 = shapex.ranked_dim %0[0] : !shapex.ranked_shape<[1,2,3,4]> -> index
+ // CHECK-DAG: %[[D1:.+]] = shapex.ranked_dim %[[SHAPE]][1]
+ %2 = shapex.ranked_dim %0[1] : !shapex.ranked_shape<[1,2,3,4]> -> index
+ // CHECK-DAG: %[[D2:.+]] = shapex.ranked_dim %[[SHAPE]][2]
+ %3 = shapex.ranked_dim %0[2] : !shapex.ranked_shape<[1,2,3,4]> -> index
+ // CHECK-DAG: %[[D3:.+]] = shapex.ranked_dim %[[SHAPE]][3]
+ %4 = shapex.ranked_dim %0[3] : !shapex.ranked_shape<[1,2,3,4]> -> index
+ // CHECK-NEXT: return %[[D0]], %[[D1]], %[[D2]], %[[D3]]
+ return %1, %2, %3, %4 : index, index, index, index
+}
+// CHECK-LABEL: func @static_stores
+func @static_stores(%arg0 : index, %arg1 : index) {
+ // CHECK-NEXT: %[[SHAPE:.+]] = shapex.const_ranked_shape
+ %0 = shapex.const_ranked_shape : !shapex.ranked_shape<[1,2,3,4]>
+ // CHECK-NOT: flow.variable.store
+ flow.variable.store %0, @static_var : !shapex.ranked_shape<[1,2,3,4]>
+ // CHECK-NEXT: return
+ return
+}
+
+// -----
+
+// CHECK-NOT: flow.variable @dynamic_var mutable
+// CHECK: flow.variable @dynamic_var_d0 1 : index
+// CHECK: flow.variable @dynamic_var_d1 mutable 0 : index
+// CHECK: flow.variable @dynamic_var_d2 mutable 0 : index
+// CHECK: flow.variable @dynamic_var_d3 4 : index
+flow.variable @dynamic_var mutable : !shapex.ranked_shape<[1,?,?,4]>
+// CHECK-LABEL: func @dynamic_loads
+func @dynamic_loads() -> (index, index, index, index) {
+ // CHECK-NOT: flow.variable.load @dynamic_var
+ // CHECK-DAG: %[[D1:.+]] = flow.variable.load @dynamic_var_d1 : index
+ // CHECK-DAG: %[[D2:.+]] = flow.variable.load @dynamic_var_d2 : index
+ // CHECK-NEXT: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[D1]], %[[D2]] : (index, index) -> !shapex.ranked_shape<[1,?,?,4]>
+ %0 = flow.variable.load @dynamic_var : !shapex.ranked_shape<[1,?,?,4]>
+ // CHECK-DAG: = shapex.ranked_dim %[[SHAPE]][0] : !shapex.ranked_shape<[1,?,?,4]> -> index
+ %1 = shapex.ranked_dim %0[0] : !shapex.ranked_shape<[1,?,?,4]> -> index
+ // CHECK-DAG: = shapex.ranked_dim %[[SHAPE]][1] : !shapex.ranked_shape<[1,?,?,4]> -> index
+ %2 = shapex.ranked_dim %0[1] : !shapex.ranked_shape<[1,?,?,4]> -> index
+ // CHECK-DAG: = shapex.ranked_dim %[[SHAPE]][2] : !shapex.ranked_shape<[1,?,?,4]> -> index
+ %3 = shapex.ranked_dim %0[2] : !shapex.ranked_shape<[1,?,?,4]> -> index
+ // CHECK-DAG: = shapex.ranked_dim %[[SHAPE]][3] : !shapex.ranked_shape<[1,?,?,4]> -> index
+ %4 = shapex.ranked_dim %0[3] : !shapex.ranked_shape<[1,?,?,4]> -> index
+ // CHECK-NEXT: return
+ return %1, %2, %3, %4 : index, index, index, index
+}
+// CHECK-LABEL: func @dynamic_stores
+// CHECK-SAME: (%[[D1:.+]]: index, %[[D2:.+]]: index)
+func @dynamic_stores(%arg0 : index, %arg1 : index) {
+ // CHECK-NEXT: %[[SHAPE:.+]] = shapex.make_ranked_shape %arg0, %arg1
+ %0 = shapex.make_ranked_shape %arg0, %arg1 : (index, index) -> !shapex.ranked_shape<[1,?,?,4]>
+ // CHECK-NEXT: %[[D1:.+]] = shapex.ranked_dim %[[SHAPE]][1] : !shapex.ranked_shape<[1,?,?,4]> -> index
+ // CHECK-NEXT: flow.variable.store %[[D1]], @dynamic_var_d1 : index
+ // CHECK-NEXT: %[[D2:.+]] = shapex.ranked_dim %[[SHAPE]][2] : !shapex.ranked_shape<[1,?,?,4]> -> index
+ // CHECK-NEXT: flow.variable.store %[[D2]], @dynamic_var_d2 : index
+ flow.variable.store %0, @dynamic_var : !shapex.ranked_shape<[1,?,?,4]>
+ // CHECK-NEXT: return
+ return
+}
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp
index 9928153..4315fc5 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp
@@ -38,14 +38,24 @@
op.initial_value(), llvm::to_vector<4>(op->getDialectAttrs()));
return success();
} else if (convertedType.isInteger(32)) {
+ auto convertedValue =
+ op.initial_value().hasValue()
+ ? rewriter.getI32IntegerAttr(static_cast<int32_t>(
+ op.initial_value().getValue().cast<IntegerAttr>().getInt()))
+ : Attribute{};
rewriter.replaceOpWithNewOp<IREE::VM::GlobalI32Op>(
op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
- op.initial_value(), llvm::to_vector<4>(op->getDialectAttrs()));
+ convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
return success();
} else if (convertedType.isInteger(64)) {
+ auto convertedValue =
+ op.initial_value().hasValue()
+ ? rewriter.getI64IntegerAttr(
+ op.initial_value().getValue().cast<IntegerAttr>().getInt())
+ : Attribute{};
rewriter.replaceOpWithNewOp<IREE::VM::GlobalI64Op>(
op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
- op.initial_value(), llvm::to_vector<4>(op->getDialectAttrs()));
+ convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
return success();
}
return op.emitOpError("unsupported variable type");