Allow !shapex.ranked_shape types in flow.variables and add a pass that
expands each shape variable into one index variable per dimension.
Cleanup as part of #4675 should move this pass to a shared place so it
can be run in both the Flow and HAL dialects.
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 7494fd1..959318a 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -92,7 +92,7 @@
     FLOW_VariableRefAttr:$variable
   );
   let results = (outs
-    AnyRankedTensor:$result
+    AnyType:$result
   );
 
   let assemblyFormat = "$variable attr-dict `:` type($result)";
@@ -112,7 +112,7 @@
     FLOW_VariablePtr:$variable
   );
   let results = (outs
-    AnyRankedTensor:$result
+    AnyType:$result
   );
 
   let assemblyFormat = "$variable attr-dict `:` type($variable) `->` type($result)";
@@ -129,7 +129,7 @@
   }];
 
   let arguments = (ins
-    AnyRankedTensor:$value,
+    AnyType:$value,
     FLOW_VariableRefAttr:$variable
   );
 
@@ -147,7 +147,7 @@
   }];
 
   let arguments = (ins
-    AnyRankedTensor:$value,
+    AnyType:$value,
     FLOW_VariablePtr:$variable
   );
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index c32aba1..f0cf29e 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -27,6 +27,7 @@
         "DispatchConfig.cpp",
         "DispatchLinalgOnTensors.cpp",
         "DispatchabilityAnalysis.cpp",
+        "ExpandVariableDynamicDims.cpp",
         "FlattenTuplesInCFG.cpp",
         "FoldCompatibleDispatchRegions.cpp",
         "FormStreams.cpp",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 4c2f4ac..8f07e4f 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -28,6 +28,7 @@
     "DispatchConfig.cpp"
     "DispatchLinalgOnTensors.cpp"
     "DispatchabilityAnalysis.cpp"
+    "ExpandVariableDynamicDims.cpp"
     "FlattenTuplesInCFG.cpp"
     "FoldCompatibleDispatchRegions.cpp"
     "FormStreams.cpp"
diff --git a/iree/compiler/Dialect/Flow/Transforms/ExpandVariableDynamicDims.cpp b/iree/compiler/Dialect/Flow/Transforms/ExpandVariableDynamicDims.cpp
new file mode 100644
index 0000000..d588938
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/ExpandVariableDynamicDims.cpp
@@ -0,0 +1,163 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+// Expands flow.variables holding !shapex.ranked_shape values into one
+// index-typed flow.variable per dimension so that constant dimensions can
+// be folded and DCE'd by later passes.
+class ExpandVariableDynamicDimsPass
+    : public PassWrapper<ExpandVariableDynamicDimsPass,
+                         OperationPass<ModuleOp>> {
+ public:
+  ExpandVariableDynamicDimsPass() = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREE::Flow::FlowDialect>();
+    registry.insert<ShapeDialect>();
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+
+    // Gathers all of the flow.variables containing shapes.
+    SmallVector<VariableOp, 4> shapeVarOps;
+    moduleOp.walk([&](VariableOp op) {
+      if (op.type().isa<Shape::RankedShapeType>()) {
+        shapeVarOps.push_back(op);
+      }
+    });
+
+    // Split each variable into one variable per dimension.
+    for (auto shapeVarOp : shapeVarOps) {
+      if (failed(expandShapeVariable(moduleOp, shapeVarOp))) {
+        signalPassFailure();
+        return;
+      }
+    }
+  }
+
+ private:
+  // Expands a flow.variable representing a shape with one variable per dim.
+  // Uses of the variable will be replaced with the per-dim ones. The original
+  // variable is only erased (and success returned) when all uses could be
+  // rewritten; otherwise it is left intact so no dangling references remain.
+  LogicalResult expandShapeVariable(ModuleOp moduleOp, VariableOp shapeVarOp) {
+    // Create one flow.variable per dimension (static or dynamic).
+    OpBuilder moduleBuilder(shapeVarOp);
+    auto shapeType = shapeVarOp.type().cast<Shape::RankedShapeType>();
+    SmallVector<VariableOp, 4> dimVarOps;
+    for (int i = 0; i < shapeType.getRank(); ++i) {
+      Attribute initialDimValue;
+      if (shapeType.isDimDynamic(i)) {
+        // Right now we choose zero for initial dynamic dim values but this
+        // needs to agree with bindings that may have expectations on query.
+        // 0 is at least easier to gracefully bail on when values are never
+        // overridden.
+        initialDimValue = moduleBuilder.getIndexAttr(0);
+      } else {
+        initialDimValue = moduleBuilder.getIndexAttr(shapeType.getStaticDim(i));
+      }
+      auto dimVarOp = moduleBuilder.create<VariableOp>(
+          shapeVarOp.getLoc(),
+          (shapeVarOp.getName() + "_d" + std::to_string(i)).str(),
+          /*isMutable=*/shapeType.isDimDynamic(i), moduleBuilder.getIndexType(),
+          initialDimValue);
+      dimVarOp.setPrivate();
+      dimVarOps.push_back(dimVarOp);
+    }
+
+    // Replace all uses of the single variable with the split ones. If any use
+    // cannot be rewritten we must not erase the original variable as doing so
+    // would leave dangling symbol references behind.
+    if (failed(replaceShapeVariableUses(moduleOp, shapeType, shapeVarOp,
+                                        dimVarOps))) {
+      return failure();
+    }
+
+    // Erase the original variable.
+    shapeVarOp.erase();
+    return success();
+  }
+
+  // Replaces uses of |shapeVarOp| in |moduleOp| with the expanded |dimVarOps|.
+  LogicalResult replaceShapeVariableUses(ModuleOp moduleOp,
+                                         Shape::RankedShapeType shapeType,
+                                         VariableOp shapeVarOp,
+                                         ArrayRef<VariableOp> dimVarOps) {
+    auto allUses = SymbolTable::getSymbolUses(shapeVarOp, moduleOp)
+                       .getValueOr(SymbolTable::UseRange({}));
+    for (auto use : allUses) {
+      if (auto loadOp = dyn_cast<VariableLoadOp>(use.getUser())) {
+        OpBuilder builder(loadOp);
+        SmallVector<Value, 4> dynamicDimValues;
+        for (int i = 0; i < shapeType.getRank(); ++i) {
+          if (!shapeType.isDimDynamic(i)) continue;
+          VariableOp dimVarOp = dimVarOps[i];
+          dynamicDimValues.push_back(builder.create<VariableLoadOp>(
+              loadOp.getLoc(), builder.getIndexType(), dimVarOp.getName()));
+        }
+        auto shapeValue = builder.create<Shape::MakeRankedShapeOp>(
+            loadOp.getLoc(), shapeType, dynamicDimValues);
+        loadOp->replaceAllUsesWith(shapeValue);
+        loadOp.erase();
+      } else if (auto storeOp = dyn_cast<VariableStoreOp>(use.getUser())) {
+        OpBuilder builder(storeOp);
+        auto shapeValue = storeOp.value();
+        for (int i = 0; i < shapeType.getRank(); ++i) {
+          if (!shapeType.isDimDynamic(i)) continue;
+          VariableOp dimVarOp = dimVarOps[i];
+          auto dynamicDimValue = builder.createOrFold<Shape::RankedDimOp>(
+              storeOp.getLoc(), shapeValue, i);
+          builder.create<VariableStoreOp>(storeOp.getLoc(), dynamicDimValue,
+                                          dimVarOp.getName());
+        }
+        storeOp.erase();
+      } else {
+        // TODO(benvanik): support indirection/addressing - should be fairly
+        // easy to do by splitting the address ops to each dim.
+        return use.getUser()->emitError()
+               << "variable action on shape is not yet supported";
+      }
+    }
+    return success();
+  }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createExpandVariableDynamicDimsPass() {
+  return std::make_unique<ExpandVariableDynamicDimsPass>();
+}
+
+static PassRegistration<ExpandVariableDynamicDimsPass> pass(
+    "iree-flow-expand-variable-dynamic-dims",
+    "Expands !shapex.ranked_shape dynamic dimensions stored in variables.",
+    [] { return std::make_unique<ExpandVariableDynamicDimsPass>(); });
+
+}  // namespace Flow
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index fffc559..8ecaa6a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -122,6 +122,11 @@
   passManager.addNestedPass<FuncOp>(
       IREE::Flow::createMaterializeExportedReflection());
 
+  // Replaces variables with !shapex.ranked_shape types with individual
+  // variables for each dimension. This allows for constant dimensions to be
+  // DCE'd in following passes.
+  passManager.addPass(IREE::Flow::createExpandVariableDynamicDimsPass());
+
   // Materialize dynamic shapes in the IR, also expanding function signatures
   // such that:
   //   - Dynamic ranked tensors: (tensor<?x?xf32>) expands to
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index a1bc2a0..31ea5a0 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -96,6 +96,9 @@
 // expected.
 std::unique_ptr<OperationPass<FuncOp>> createMergeExportedReflection();
 
+// Expands dynamic !shapex.ranked_shape dimensions in variables.
+std::unique_ptr<OperationPass<ModuleOp>> createExpandVariableDynamicDimsPass();
+
 //===----------------------------------------------------------------------===//
 // Dispatches (flow.dispatch.region)
 //===----------------------------------------------------------------------===//
@@ -193,6 +196,7 @@
   createPostPartitioningConversionPass();
   createMaterializeExportedReflection();
   createMergeExportedReflection();
+  createExpandVariableDynamicDimsPass();
   createDispatchabilityAnalysisPass();
   createIdentifyDispatchRegionsPass();
   createIdentifyDispatchRegions2Pass();
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/expand_variable_dynamic_dims.mlir b/iree/compiler/Dialect/Flow/Transforms/test/expand_variable_dynamic_dims.mlir
new file mode 100644
index 0000000..2588b9a
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/expand_variable_dynamic_dims.mlir
@@ -0,0 +1,73 @@
+// RUN: iree-opt -split-input-file -iree-flow-expand-variable-dynamic-dims %s | IreeFileCheck %s
+
+// CHECK-NOT: flow.variable @static_var mutable
+// CHECK: flow.variable @static_var_d0 1 : index
+// CHECK: flow.variable @static_var_d1 2 : index
+// CHECK: flow.variable @static_var_d2 3 : index
+// CHECK: flow.variable @static_var_d3 4 : index
+flow.variable @static_var mutable : !shapex.ranked_shape<[1,2,3,4]>
+// CHECK-LABEL: func @static_loads
+func @static_loads() -> (index, index, index, index) {
+  // CHECK-NOT: flow.variable.load
+  // CHECK-NEXT: %[[SHAPE:.+]] = shapex.make_ranked_shape
+  %0 = flow.variable.load @static_var : !shapex.ranked_shape<[1,2,3,4]>
+  // CHECK-DAG: %[[D0:.+]] = shapex.ranked_dim %[[SHAPE]][0]
+  %1 = shapex.ranked_dim %0[0] : !shapex.ranked_shape<[1,2,3,4]> -> index
+  // CHECK-DAG: %[[D1:.+]] = shapex.ranked_dim %[[SHAPE]][1]
+  %2 = shapex.ranked_dim %0[1] : !shapex.ranked_shape<[1,2,3,4]> -> index
+  // CHECK-DAG: %[[D2:.+]] = shapex.ranked_dim %[[SHAPE]][2]
+  %3 = shapex.ranked_dim %0[2] : !shapex.ranked_shape<[1,2,3,4]> -> index
+  // CHECK-DAG: %[[D3:.+]] = shapex.ranked_dim %[[SHAPE]][3]
+  %4 = shapex.ranked_dim %0[3] : !shapex.ranked_shape<[1,2,3,4]> -> index
+  // CHECK-NEXT: return %[[D0]], %[[D1]], %[[D2]], %[[D3]]
+  return %1, %2, %3, %4 : index, index, index, index
+}
+// CHECK-LABEL: func @static_stores
+func @static_stores(%arg0 : index, %arg1 : index) {
+  // CHECK-NEXT: %[[SHAPE:.+]] = shapex.const_ranked_shape
+  %0 = shapex.const_ranked_shape : !shapex.ranked_shape<[1,2,3,4]>
+  // CHECK-NOT: flow.variable.store
+  flow.variable.store %0, @static_var : !shapex.ranked_shape<[1,2,3,4]>
+  // CHECK-NEXT: return
+  return
+}
+
+// -----
+
+// CHECK-NOT: flow.variable @dynamic_var mutable
+// CHECK: flow.variable @dynamic_var_d0 1 : index
+// CHECK: flow.variable @dynamic_var_d1 mutable 0 : index
+// CHECK: flow.variable @dynamic_var_d2 mutable 0 : index
+// CHECK: flow.variable @dynamic_var_d3 4 : index
+flow.variable @dynamic_var mutable : !shapex.ranked_shape<[1,?,?,4]>
+// CHECK-LABEL: func @dynamic_loads
+func @dynamic_loads() -> (index, index, index, index) {
+  // CHECK-NOT: flow.variable.load @dynamic_var
+  // CHECK-DAG: %[[D1:.+]] = flow.variable.load @dynamic_var_d1 : index
+  // CHECK-DAG: %[[D2:.+]] = flow.variable.load @dynamic_var_d2 : index
+  // CHECK-NEXT: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[D1]], %[[D2]] : (index, index) -> !shapex.ranked_shape<[1,?,?,4]>
+  %0 = flow.variable.load @dynamic_var : !shapex.ranked_shape<[1,?,?,4]>
+  // CHECK-DAG: = shapex.ranked_dim %[[SHAPE]][0] : !shapex.ranked_shape<[1,?,?,4]> -> index
+  %1 = shapex.ranked_dim %0[0] : !shapex.ranked_shape<[1,?,?,4]> -> index
+  // CHECK-DAG: = shapex.ranked_dim %[[SHAPE]][1] : !shapex.ranked_shape<[1,?,?,4]> -> index
+  %2 = shapex.ranked_dim %0[1] : !shapex.ranked_shape<[1,?,?,4]> -> index
+  // CHECK-DAG: = shapex.ranked_dim %[[SHAPE]][2] : !shapex.ranked_shape<[1,?,?,4]> -> index
+  %3 = shapex.ranked_dim %0[2] : !shapex.ranked_shape<[1,?,?,4]> -> index
+  // CHECK-DAG: = shapex.ranked_dim %[[SHAPE]][3] : !shapex.ranked_shape<[1,?,?,4]> -> index
+  %4 = shapex.ranked_dim %0[3] : !shapex.ranked_shape<[1,?,?,4]> -> index
+  // CHECK-NEXT: return
+  return %1, %2, %3, %4 : index, index, index, index
+}
+// CHECK-LABEL: func @dynamic_stores
+// CHECK-SAME: (%[[D1:.+]]: index, %[[D2:.+]]: index)
+func @dynamic_stores(%arg0 : index, %arg1 : index) {
+  // CHECK-NEXT: %[[SHAPE:.+]] = shapex.make_ranked_shape %arg0, %arg1
+  %0 = shapex.make_ranked_shape %arg0, %arg1 : (index, index) -> !shapex.ranked_shape<[1,?,?,4]>
+  // CHECK-NEXT: %[[D1:.+]] = shapex.ranked_dim %[[SHAPE]][1] : !shapex.ranked_shape<[1,?,?,4]> -> index
+  // CHECK-NEXT: flow.variable.store %[[D1]], @dynamic_var_d1 : index
+  // CHECK-NEXT: %[[D2:.+]] = shapex.ranked_dim %[[SHAPE]][2] : !shapex.ranked_shape<[1,?,?,4]> -> index
+  // CHECK-NEXT: flow.variable.store %[[D2]], @dynamic_var_d2 : index
+  flow.variable.store %0, @dynamic_var : !shapex.ranked_shape<[1,?,?,4]>
+  // CHECK-NEXT: return
+  return
+}
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp
index 9928153..4315fc5 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertVariableOps.cpp
@@ -38,14 +38,24 @@
           op.initial_value(), llvm::to_vector<4>(op->getDialectAttrs()));
       return success();
     } else if (convertedType.isInteger(32)) {
+      auto convertedValue =
+          op.initial_value().hasValue()
+              ? rewriter.getI32IntegerAttr(static_cast<int32_t>(
+                    op.initial_value().getValue().cast<IntegerAttr>().getInt()))
+              : Attribute{};
       rewriter.replaceOpWithNewOp<IREE::VM::GlobalI32Op>(
           op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
-          op.initial_value(), llvm::to_vector<4>(op->getDialectAttrs()));
+          convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
       return success();
     } else if (convertedType.isInteger(64)) {
+      auto convertedValue =
+          op.initial_value().hasValue()
+              ? rewriter.getI64IntegerAttr(
+                    op.initial_value().getValue().cast<IntegerAttr>().getInt())
+              : Attribute{};
       rewriter.replaceOpWithNewOp<IREE::VM::GlobalI64Op>(
           op, op.sym_name(), op.is_mutable(), convertedType, op.initializer(),
-          op.initial_value(), llvm::to_vector<4>(op->getDialectAttrs()));
+          convertedValue, llvm::to_vector<4>(op->getDialectAttrs()));
       return success();
     }
     return op.emitOpError("unsupported variable type");