Generalizing MHLO type legalization to support util.global (and others). (#6803)

diff --git a/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp b/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
index ee7ac08..85a700e 100644
--- a/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
+++ b/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Dialect/Flow/Transforms/TypeConverter.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "iree/compiler/InputConversion/MHLO/PassDetail.h"
 #include "iree/compiler/InputConversion/MHLO/Passes.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
@@ -26,10 +27,12 @@
 namespace mlir {
 namespace iree_compiler {
 
-namespace {
+static Attribute convertAttribute(Location loc, Attribute value,
+                                  FlowTypeConverter &typeConverter) {
+  if (auto oldTypeAttr = value.dyn_cast<TypeAttr>()) {
+    return TypeAttr::get(typeConverter.convertType(oldTypeAttr.getValue()));
+  }
 
-Attribute convertAttribute(Location loc, Attribute value,
-                           FlowTypeConverter &typeConverter) {
   auto newType = typeConverter.convertType(value.getType());
   if (value.getType() == newType) {
     return value;
@@ -68,14 +71,14 @@
   return {};
 }
 
-LogicalResult convertRegion(Region &oldRegion, Region &newRegion,
-                            FlowTypeConverter &typeConverter,
-                            BlockAndValueMapping &mapping);
+static LogicalResult convertRegion(Region &oldRegion, Region &newRegion,
+                                   FlowTypeConverter &typeConverter,
+                                   BlockAndValueMapping &mapping);
 
-LogicalResult convertOperation(Operation *oldOp,
-                               FlowTypeConverter &typeConverter,
-                               BlockAndValueMapping &mapping,
-                               OpBuilder &builder) {
+static LogicalResult convertOperation(Operation *oldOp,
+                                      FlowTypeConverter &typeConverter,
+                                      BlockAndValueMapping &mapping,
+                                      OpBuilder &builder) {
   if (llvm::isa<linalg::LinalgOp>(oldOp)) {
     // Currently we assume all Linalg structured ops only contain valid types.
     // We allow to convert non-structured operation like
@@ -92,14 +95,17 @@
     }
   }
 
-  if (llvm::isa<mlir::ConstantOp>(oldOp) || llvm::isa<mhlo::ConstOp>(oldOp)) {
-    // Deal with all value-based constant ops generically.
-    Attribute oldValue = oldOp->getAttr("value");
-    auto newValue = convertAttribute(oldOp->getLoc(), oldValue, typeConverter);
-    if (!newValue) {
-      return failure();
+  if (llvm::isa<mlir::ConstantOp>(oldOp) || llvm::isa<mhlo::ConstOp>(oldOp) ||
+      llvm::isa<IREE::Util::GlobalOp>(oldOp)) {
+    for (auto attr : oldOp->getAttrs()) {
+      auto newAttr =
+          convertAttribute(oldOp->getLoc(), attr.second, typeConverter);
+      if (!newAttr) {
+        return oldOp->emitOpError()
+               << "failed to convert attribute " << attr.first;
+      }
+      state.addAttribute(attr.first, newAttr);
     }
-    state.addAttribute("value", newValue);
   } else {
     state.attributes = llvm::to_vector<4>(oldOp->getAttrs());
   }
@@ -141,9 +147,9 @@
   return success();
 }
 
-LogicalResult convertBlock(Block &oldBlock, Block &newBlock,
-                           FlowTypeConverter &typeConverter,
-                           BlockAndValueMapping &mapping) {
+static LogicalResult convertBlock(Block &oldBlock, Block &newBlock,
+                                  FlowTypeConverter &typeConverter,
+                                  BlockAndValueMapping &mapping) {
   OpBuilder builder(oldBlock.getParent()->getContext());
   builder.setInsertionPointToEnd(&newBlock);
   for (auto &oldOp : oldBlock) {
@@ -154,9 +160,9 @@
   return success();
 }
 
-LogicalResult convertRegion(Region &oldRegion, Region &newRegion,
-                            FlowTypeConverter &typeConverter,
-                            BlockAndValueMapping &mapping) {
+static LogicalResult convertRegion(Region &oldRegion, Region &newRegion,
+                                   FlowTypeConverter &typeConverter,
+                                   BlockAndValueMapping &mapping) {
   OpBuilder builder(oldRegion.getContext());
   for (auto &oldBlock : oldRegion) {
     auto &newBlock = *builder.createBlock(&newRegion);
@@ -181,7 +187,36 @@
   return success();
 }
 
-}  // namespace
+static LogicalResult convertFunc(mlir::FuncOp oldFuncOp,
+                                 FlowTypeConverter &typeConverter,
+                                 OpBuilder &moduleBuilder) {
+  auto oldType = oldFuncOp.getType();
+  TypeConverter::SignatureConversion signature(oldType.getNumInputs());
+  for (unsigned i = 0, e = oldType.getNumInputs(); i != e; ++i) {
+    if (failed(typeConverter.convertSignatureArg(i, oldType.getInput(i),
+                                                 signature))) {
+      return oldFuncOp.emitOpError()
+             << "unable to legalize type of input " << i;
+    }
+  }
+  SmallVector<Type, 1> convertedResults;
+  if (failed(
+          typeConverter.convertTypes(oldType.getResults(), convertedResults))) {
+    return oldFuncOp.emitOpError() << "unable to legalize result types";
+  }
+
+  auto newFuncOp = cast<FuncOp>(moduleBuilder.cloneWithoutRegions(*oldFuncOp));
+  newFuncOp.setType(FunctionType::get(
+      oldFuncOp.getContext(), signature.getConvertedTypes(), convertedResults));
+
+  BlockAndValueMapping mapping;
+  if (failed(convertRegion(oldFuncOp.getBody(), newFuncOp.getBody(),
+                           typeConverter, mapping))) {
+    return failure();
+  }
+
+  return success();
+}
 
 class LegalizeInputTypesPass
     : public LegalizeInputTypesBase<LegalizeInputTypesPass> {
@@ -190,39 +225,24 @@
     auto moduleOp = getOperation();
     FlowTypeConverter typeConverter;
 
-    auto oldFuncOps = llvm::to_vector<16>(moduleOp.getOps<FuncOp>());
-    for (auto oldFuncOp : oldFuncOps) {
+    auto oldOps = llvm::to_vector<4>(llvm::map_range(
+        moduleOp.body().getOps(), [](Operation &op) { return &op; }));
+    for (auto *oldOp : oldOps) {
       OpBuilder moduleBuilder(moduleOp);
-      moduleBuilder.setInsertionPoint(oldFuncOp);
-
-      auto oldType = oldFuncOp.getType();
-      TypeConverter::SignatureConversion signature(oldType.getNumInputs());
-      for (unsigned i = 0, e = oldType.getNumInputs(); i != e; ++i) {
-        if (failed(typeConverter.convertSignatureArg(i, oldType.getInput(i),
-                                                     signature))) {
-          oldFuncOp.emitOpError() << "unable to legalize type of input " << i;
+      moduleBuilder.setInsertionPoint(oldOp);
+      if (auto oldFuncOp = dyn_cast<mlir::FuncOp>(oldOp)) {
+        if (failed(convertFunc(oldFuncOp, typeConverter, moduleBuilder))) {
           return signalPassFailure();
         }
+        oldOp->erase();
+      } else {
+        BlockAndValueMapping mapping;
+        if (failed(convertOperation(oldOp, typeConverter, mapping,
+                                    moduleBuilder))) {
+          return signalPassFailure();
+        }
+        oldOp->erase();
       }
-      SmallVector<Type, 1> convertedResults;
-      if (failed(typeConverter.convertTypes(oldType.getResults(),
-                                            convertedResults))) {
-        oldFuncOp.emitOpError() << "unable to legalize result types";
-        return signalPassFailure();
-      }
-
-      auto newFuncOp =
-          cast<FuncOp>(moduleBuilder.cloneWithoutRegions(*oldFuncOp));
-      newFuncOp.setType(FunctionType::get(
-          &getContext(), signature.getConvertedTypes(), convertedResults));
-
-      BlockAndValueMapping mapping;
-      if (failed(convertRegion(oldFuncOp.getBody(), newFuncOp.getBody(),
-                               typeConverter, mapping))) {
-        return signalPassFailure();
-      }
-
-      oldFuncOp.erase();
     }
   }
 };
diff --git a/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir b/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
index f1e06b4..8445929 100644
--- a/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
@@ -154,3 +154,23 @@
   %0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<9xi64> into tensor<1x9xi64>
   return %0 : tensor<1x9xi64>
 }
+
+// -----
+
+// CHECK: util.global public mutable @[[VAR:.+]] = dense<0> : tensor<i32>
+// CHECK: util.global.load @[[VAR]]
+// CHECK: util.global.store %{{.+}}, @[[VAR]]
+util.global mutable @readwritevar = dense<0> : tensor<i64>
+builtin.func @foo(%arg0 : tensor<i64>) {
+  %0 = util.global.load @readwritevar : tensor<i64>
+  %1 = chlo.broadcast_add %0, %arg0 : (tensor<i64>, tensor<i64>) -> tensor<i64>
+  util.global.store %1, @readwritevar : tensor<i64>
+  return
+}
+
+// -----
+
+// CHECK: util.global private @{{.+}} initializer(@initializer) : tensor<4xi32>
+util.global private @v_initializer initializer(@initializer) : tensor<4xi64>
+// CHECK: func private @initializer() -> tensor<4xi32>
+func private @initializer() -> tensor<4xi64>