Add hoist support for index type (#18303)
https://github.com/iree-org/iree/issues/18232
Signed-off-by: jinchen62 <jinchenye62@gmail.com>
diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
index 4fd578a..9a4d89a 100644
--- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
+++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
@@ -628,6 +628,7 @@
requestedTargetDevice);
compileOptions->targetOptions.f32Extension = true;
compileOptions->targetOptions.f64Extension = true;
+ compileOptions->targetOptions.indexBits = 64;
compileOptions->targetOptions.truncateUnsupportedFloats = false;
compileOptions->inputOptions.demoteF64ToF32 = false;
if (requestedTargetDevice == "vmvx" || !hasRequestedTargetDevice) {
@@ -677,14 +678,15 @@
s.addScalarType(b.getIntegerType(16));
s.addScalarType(b.getIntegerType(32));
s.addScalarType(b.getIntegerType(64));
+ s.addScalarType(b.getIndexType());
s.addScalarType(b.getF32Type());
s.addElementType(b.getIntegerType(1));
-
s.addElementType(b.getIntegerType(8));
s.addElementType(b.getIntegerType(16));
s.addElementType(b.getIntegerType(32));
s.addElementType(b.getIntegerType(64));
+ s.addElementType(b.getIndexType());
s.addElementType(b.getF32Type());
if (requestedTargetDevice != "vmvx" && hasRequestedTargetDevice) {
// The full compilers support additional types.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel
index 0b2f50f..27f9a8e 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel
@@ -38,6 +38,7 @@
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
],
)
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt
index 0b5911b..657d9aa 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt
@@ -32,6 +32,7 @@
"HoistableTypeInterface.cpp"
DEPS
LLVMSupport
+ MLIRArithDialect
MLIRIR
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Util::IR
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp
index 76c90cf..f5a3135 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp
@@ -9,6 +9,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "llvm/Support/MathExtras.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypes.h"
namespace mlir::iree_compiler {
@@ -85,6 +86,36 @@
}
};
+struct HoistableIndexTypeInterface
+ : public IREE::Util::HoistableTypeInterface::ExternalModel<
+ HoistableIndexTypeInterface, IndexType> {
+ bool isHoistableType(Type type) const { return true; }
+ bool isHoistableLeafType(Type type) const { return true; }
+ Type getPreferredStorageType(Type type) const {
+ // Conservatively enforce 64 bit indices for
+ // (potentially constant evaluated) hoisted globals.
+ return IntegerType::get(type.getContext(), 64);
+ }
+ static Value encodeStorageType(OpBuilder &builder, Location loc,
+ Type storageType, Value init) {
+ auto storageIndexType = dyn_cast<IntegerType>(storageType);
+ if (!storageIndexType || init.getType() == storageIndexType ||
+ !isa<IndexType>(init.getType())) {
+ return init;
+ }
+ return builder.create<arith::IndexCastOp>(loc, storageType, init);
+ }
+ static Value decodeStorageType(OpBuilder &builder, Location loc,
+ Type originalType, Value loadedGlobal) {
+ auto originalIndexType = dyn_cast<IndexType>(originalType);
+ if (!originalIndexType || loadedGlobal.getType() == originalIndexType ||
+ !isa<IntegerType>(loadedGlobal.getType())) {
+ return loadedGlobal;
+ }
+ return builder.create<arith::IndexCastOp>(loc, originalType, loadedGlobal);
+ }
+};
+
//===----------------------------------------------------------------------===//
// IREE specific post analysis transformations.
//===----------------------------------------------------------------------===//
@@ -93,6 +124,7 @@
// Register hoistable type interfaces for builtin types.
registry.addExtension(+[](MLIRContext *ctx) {
RankedTensorType::attachInterface<HoistableTensorTypeInterface>(*ctx);
+ IndexType::attachInterface<HoistableIndexTypeInterface>(*ctx);
});
}
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
index e289f07..67e6315 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
@@ -156,3 +156,26 @@
util.return %1 : tensor<i32>
}
}
+
+// -----
+
+// CHECK-LABEL: @hoist_index
+module @hoist_index {
+ // CHECK: util.global private @[[HOISTED:.*]] : i64
+ // CHECK: util.initializer
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CEXPR:.*]] = "iree_unregistered.const_expr"(%[[C0]])
+ // CHECK: %[[CAST:.*]] = arith.index_cast %[[CEXPR]] : index to i64
+ // CHECK: util.global.store %[[CAST]], @[[HOISTED]] : i64
+ // CHECK: util.return
+
+ // CHECK: util.func public @main() -> index
+ // CHECK: %[[GLOBAL_LD:.*]] = util.global.load immutable @[[HOISTED]] : i64
+ // CHECK: %[[ORIG_VAL:.*]] = arith.index_cast %[[GLOBAL_LD]] : i64 to index
+ // CHECK: util.return %[[ORIG_VAL]]
+ util.func public @main() -> (index) {
+ %0 = arith.constant 0 : index
+ %1 = "iree_unregistered.const_expr"(%0) : (index) -> index
+ util.return %1 : index
+ }
+}