Generalizing MHLO type legalization to support util.global (and others). (#6803)
diff --git a/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp b/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
index ee7ac08..85a700e 100644
--- a/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
+++ b/iree/compiler/InputConversion/MHLO/LegalizeInputTypes.cpp
@@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/Transforms/TypeConverter.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/InputConversion/MHLO/PassDetail.h"
#include "iree/compiler/InputConversion/MHLO/Passes.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
@@ -26,10 +27,12 @@
namespace mlir {
namespace iree_compiler {
-namespace {
+static Attribute convertAttribute(Location loc, Attribute value,
+ FlowTypeConverter &typeConverter) {
+ if (auto oldTypeAttr = value.dyn_cast<TypeAttr>()) {
+ return TypeAttr::get(typeConverter.convertType(oldTypeAttr.getValue()));
+ }
-Attribute convertAttribute(Location loc, Attribute value,
- FlowTypeConverter &typeConverter) {
auto newType = typeConverter.convertType(value.getType());
if (value.getType() == newType) {
return value;
@@ -68,14 +71,14 @@
return {};
}
-LogicalResult convertRegion(Region &oldRegion, Region &newRegion,
- FlowTypeConverter &typeConverter,
- BlockAndValueMapping &mapping);
+static LogicalResult convertRegion(Region &oldRegion, Region &newRegion,
+ FlowTypeConverter &typeConverter,
+ BlockAndValueMapping &mapping);
-LogicalResult convertOperation(Operation *oldOp,
- FlowTypeConverter &typeConverter,
- BlockAndValueMapping &mapping,
- OpBuilder &builder) {
+static LogicalResult convertOperation(Operation *oldOp,
+ FlowTypeConverter &typeConverter,
+ BlockAndValueMapping &mapping,
+ OpBuilder &builder) {
if (llvm::isa<linalg::LinalgOp>(oldOp)) {
// Currently we assume all Linalg structured ops only contain valid types.
// We allow to convert non-structured operation like
@@ -92,14 +95,17 @@
}
}
- if (llvm::isa<mlir::ConstantOp>(oldOp) || llvm::isa<mhlo::ConstOp>(oldOp)) {
- // Deal with all value-based constant ops generically.
- Attribute oldValue = oldOp->getAttr("value");
- auto newValue = convertAttribute(oldOp->getLoc(), oldValue, typeConverter);
- if (!newValue) {
- return failure();
+ if (llvm::isa<mlir::ConstantOp>(oldOp) || llvm::isa<mhlo::ConstOp>(oldOp) ||
+ llvm::isa<IREE::Util::GlobalOp>(oldOp)) {
+ for (auto attr : oldOp->getAttrs()) {
+ auto newAttr =
+ convertAttribute(oldOp->getLoc(), attr.second, typeConverter);
+ if (!newAttr) {
+ return oldOp->emitOpError()
+ << "failed to convert attribute " << attr.first;
+ }
+ state.addAttribute(attr.first, newAttr);
}
- state.addAttribute("value", newValue);
} else {
state.attributes = llvm::to_vector<4>(oldOp->getAttrs());
}
@@ -141,9 +147,9 @@
return success();
}
-LogicalResult convertBlock(Block &oldBlock, Block &newBlock,
- FlowTypeConverter &typeConverter,
- BlockAndValueMapping &mapping) {
+static LogicalResult convertBlock(Block &oldBlock, Block &newBlock,
+ FlowTypeConverter &typeConverter,
+ BlockAndValueMapping &mapping) {
OpBuilder builder(oldBlock.getParent()->getContext());
builder.setInsertionPointToEnd(&newBlock);
for (auto &oldOp : oldBlock) {
@@ -154,9 +160,9 @@
return success();
}
-LogicalResult convertRegion(Region &oldRegion, Region &newRegion,
- FlowTypeConverter &typeConverter,
- BlockAndValueMapping &mapping) {
+static LogicalResult convertRegion(Region &oldRegion, Region &newRegion,
+ FlowTypeConverter &typeConverter,
+ BlockAndValueMapping &mapping) {
OpBuilder builder(oldRegion.getContext());
for (auto &oldBlock : oldRegion) {
auto &newBlock = *builder.createBlock(&newRegion);
@@ -181,7 +187,36 @@
return success();
}
-} // namespace
+static LogicalResult convertFunc(mlir::FuncOp oldFuncOp,
+ FlowTypeConverter &typeConverter,
+ OpBuilder &moduleBuilder) {
+ auto oldType = oldFuncOp.getType();
+ TypeConverter::SignatureConversion signature(oldType.getNumInputs());
+ for (unsigned i = 0, e = oldType.getNumInputs(); i != e; ++i) {
+ if (failed(typeConverter.convertSignatureArg(i, oldType.getInput(i),
+ signature))) {
+ return oldFuncOp.emitOpError()
+ << "unable to legalize type of input " << i;
+ }
+ }
+ SmallVector<Type, 1> convertedResults;
+ if (failed(
+ typeConverter.convertTypes(oldType.getResults(), convertedResults))) {
+ return oldFuncOp.emitOpError() << "unable to legalize result types";
+ }
+
+ auto newFuncOp = cast<FuncOp>(moduleBuilder.cloneWithoutRegions(*oldFuncOp));
+ newFuncOp.setType(FunctionType::get(
+ oldFuncOp.getContext(), signature.getConvertedTypes(), convertedResults));
+
+ BlockAndValueMapping mapping;
+ if (failed(convertRegion(oldFuncOp.getBody(), newFuncOp.getBody(),
+ typeConverter, mapping))) {
+ return failure();
+ }
+
+ return success();
+}
class LegalizeInputTypesPass
: public LegalizeInputTypesBase<LegalizeInputTypesPass> {
@@ -190,39 +225,24 @@
auto moduleOp = getOperation();
FlowTypeConverter typeConverter;
- auto oldFuncOps = llvm::to_vector<16>(moduleOp.getOps<FuncOp>());
- for (auto oldFuncOp : oldFuncOps) {
+ auto oldOps = llvm::to_vector<4>(llvm::map_range(
+ moduleOp.body().getOps(), [](Operation &op) { return &op; }));
+ for (auto *oldOp : oldOps) {
OpBuilder moduleBuilder(moduleOp);
- moduleBuilder.setInsertionPoint(oldFuncOp);
-
- auto oldType = oldFuncOp.getType();
- TypeConverter::SignatureConversion signature(oldType.getNumInputs());
- for (unsigned i = 0, e = oldType.getNumInputs(); i != e; ++i) {
- if (failed(typeConverter.convertSignatureArg(i, oldType.getInput(i),
- signature))) {
- oldFuncOp.emitOpError() << "unable to legalize type of input " << i;
+ moduleBuilder.setInsertionPoint(oldOp);
+ if (auto oldFuncOp = dyn_cast<mlir::FuncOp>(oldOp)) {
+ if (failed(convertFunc(oldFuncOp, typeConverter, moduleBuilder))) {
return signalPassFailure();
}
+ oldOp->erase();
+ } else {
+ BlockAndValueMapping mapping;
+ if (failed(convertOperation(oldOp, typeConverter, mapping,
+ moduleBuilder))) {
+ return signalPassFailure();
+ }
+ oldOp->erase();
}
- SmallVector<Type, 1> convertedResults;
- if (failed(typeConverter.convertTypes(oldType.getResults(),
- convertedResults))) {
- oldFuncOp.emitOpError() << "unable to legalize result types";
- return signalPassFailure();
- }
-
- auto newFuncOp =
- cast<FuncOp>(moduleBuilder.cloneWithoutRegions(*oldFuncOp));
- newFuncOp.setType(FunctionType::get(
- &getContext(), signature.getConvertedTypes(), convertedResults));
-
- BlockAndValueMapping mapping;
- if (failed(convertRegion(oldFuncOp.getBody(), newFuncOp.getBody(),
- typeConverter, mapping))) {
- return signalPassFailure();
- }
-
- oldFuncOp.erase();
}
}
};
diff --git a/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir b/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
index f1e06b4..8445929 100644
--- a/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/legalize_input_types.mlir
@@ -154,3 +154,23 @@
%0 = linalg.tensor_expand_shape %arg0 [[0, 1]] : tensor<9xi64> into tensor<1x9xi64>
return %0 : tensor<1x9xi64>
}
+
+// -----
+
+// CHECK: util.global public mutable @[[VAR:.+]] = dense<0> : tensor<i32>
+// CHECK: util.global.load @[[VAR]]
+// CHECK: util.global.store %{{.+}}, @[[VAR]]
+util.global mutable @readwritevar = dense<0> : tensor<i64>
+builtin.func @foo(%arg0 : tensor<i64>) {
+ %0 = util.global.load @readwritevar : tensor<i64>
+ %1 = chlo.broadcast_add %0, %arg0 : (tensor<i64>, tensor<i64>) -> tensor<i64>
+ util.global.store %1, @readwritevar : tensor<i64>
+ return
+}
+
+// -----
+
+// CHECK: util.global private @{{.+}} initializer(@initializer) : tensor<4xi32>
+util.global private @v_initializer initializer(@initializer) : tensor<4xi64>
+// CHECK: func private @initializer() -> tensor<4xi32>
+func private @initializer() -> tensor<4xi64>