Add hoist support for index type (#18303)

https://github.com/iree-org/iree/issues/18232

Signed-off-by: jinchen62 <jinchenye62@gmail.com>
diff --git a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
index 4fd578a..9a4d89a 100644
--- a/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
+++ b/compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
@@ -628,6 +628,7 @@
         requestedTargetDevice);
     compileOptions->targetOptions.f32Extension = true;
     compileOptions->targetOptions.f64Extension = true;
+    compileOptions->targetOptions.indexBits = 64;
     compileOptions->targetOptions.truncateUnsupportedFloats = false;
     compileOptions->inputOptions.demoteF64ToF32 = false;
     if (requestedTargetDevice == "vmvx" || !hasRequestedTargetDevice) {
@@ -677,14 +678,15 @@
     s.addScalarType(b.getIntegerType(16));
     s.addScalarType(b.getIntegerType(32));
     s.addScalarType(b.getIntegerType(64));
+    s.addScalarType(b.getIndexType());
     s.addScalarType(b.getF32Type());
 
     s.addElementType(b.getIntegerType(1));
-
     s.addElementType(b.getIntegerType(8));
     s.addElementType(b.getIntegerType(16));
     s.addElementType(b.getIntegerType(32));
     s.addElementType(b.getIntegerType(64));
+    s.addElementType(b.getIndexType());
     s.addElementType(b.getF32Type());
     if (requestedTargetDevice != "vmvx" && hasRequestedTargetDevice) {
       // The full compilers support additional types.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel
index 0b2f50f..27f9a8e 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/BUILD.bazel
@@ -38,6 +38,7 @@
         "//compiler/src/iree/compiler/Dialect/Flow/IR",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:IR",
     ],
 )
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt
index 0b5911b..657d9aa 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/CMakeLists.txt
@@ -32,6 +32,7 @@
     "HoistableTypeInterface.cpp"
   DEPS
     LLVMSupport
+    MLIRArithDialect
     MLIRIR
     iree::compiler::Dialect::Flow::IR
     iree::compiler::Dialect::Util::IR
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp
index 76c90cf..f5a3135 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Interfaces/HoistableTypeInterface.cpp
@@ -9,6 +9,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
 #include "llvm/Support/MathExtras.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/BuiltinTypes.h"
 
 namespace mlir::iree_compiler {
@@ -85,6 +86,36 @@
   }
 };
 
+struct HoistableIndexTypeInterface
+    : public IREE::Util::HoistableTypeInterface::ExternalModel<
+          HoistableIndexTypeInterface, IndexType> {
+  bool isHoistableType(Type type) const { return true; }
+  bool isHoistableLeafType(Type type) const { return true; }
+  Type getPreferredStorageType(Type type) const {
+    // Conservatively enforce 64 bit indices for
+    // (potentially constant evaluated) hoisted globals.
+    return IntegerType::get(type.getContext(), 64);
+  }
+  static Value encodeStorageType(OpBuilder &builder, Location loc,
+                                 Type storageType, Value init) {
+    auto storageIndexType = dyn_cast<IntegerType>(storageType);
+    if (!storageIndexType || init.getType() == storageIndexType ||
+        !isa<IndexType>(init.getType())) {
+      return init;
+    }
+    return builder.create<arith::IndexCastOp>(loc, storageType, init);
+  }
+  static Value decodeStorageType(OpBuilder &builder, Location loc,
+                                 Type originalType, Value loadedGlobal) {
+    auto originalIndexType = dyn_cast<IndexType>(originalType);
+    if (!originalIndexType || loadedGlobal.getType() == originalIndexType ||
+        !isa<IntegerType>(loadedGlobal.getType())) {
+      return loadedGlobal;
+    }
+    return builder.create<arith::IndexCastOp>(loc, originalType, loadedGlobal);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // IREE specific post analysis transformations.
 //===----------------------------------------------------------------------===//
@@ -93,6 +124,7 @@
   // Register hoistable type interfaces for builtin types.
   registry.addExtension(+[](MLIRContext *ctx) {
     RankedTensorType::attachInterface<HoistableTensorTypeInterface>(*ctx);
+    IndexType::attachInterface<HoistableIndexTypeInterface>(*ctx);
   });
 }
 
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
index e289f07..67e6315 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/hoist_into_globals.mlir
@@ -156,3 +156,26 @@
     util.return %1 : tensor<i32>
   }
 }
+
+// -----
+
+// CHECK-LABEL: @hoist_index
+module @hoist_index {
+  // CHECK: util.global private @[[HOISTED:.*]] : i64
+  // CHECK: util.initializer
+  // CHECK:   %[[C0:.*]] = arith.constant 0 : index
+  // CHECK:   %[[CEXPR:.*]] = "iree_unregistered.const_expr"(%[[C0]])
+  // CHECK:   %[[CAST:.*]] = arith.index_cast %[[CEXPR]] : index to i64
+  // CHECK:   util.global.store %[[CAST]], @[[HOISTED]] : i64
+  // CHECK:   util.return
+
+  // CHECK: util.func public @main() -> index
+  // CHECK:   %[[GLOBAL_LD:.*]] = util.global.load immutable @[[HOISTED]] : i64
+  // CHECK:   %[[ORIG_VAL:.*]] = arith.index_cast %[[GLOBAL_LD]] : i64 to index
+  // CHECK:   util.return %[[ORIG_VAL]]
+  util.func public @main() -> (index) {
+    %0 = arith.constant 0 : index
+    %1 = "iree_unregistered.const_expr"(%0) : (index) -> index
+    util.return %1 : index
+  }
+}