Adding an IntegerSet utility and making PackConstants use it. (#18013)

This is an extension of IndexSet to integer types so that it can be used
to elide more values. The PackConstants pass produces a large number of
duplicate values for parameters that can nearly all be elided in our
current common cases (all parameter slices at offset 0, etc). In some
large models with lots of parameters this saves 10k+ redundant values
from being created/needing to be cleaned up.
diff --git a/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp b/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp
index 2150afa..21a0086 100644
--- a/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/ConvertCollectives.cpp
@@ -11,7 +11,7 @@
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinAttributes.h"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index f090386..588f560 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -12,7 +12,7 @@
 #include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Support/Path.h"
 #include "llvm/Support/ToolOutputFile.h"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
index 0ca39e4..42b8424 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
@@ -11,7 +11,7 @@
 
 #include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSwitch.h"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp
index 7cf07ac..db04234 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/LayoutSlices.cpp
@@ -13,7 +13,7 @@
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/AsmState.h"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
index 0ba184a..c38f4dc 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/PackConstants.cpp
@@ -11,7 +11,7 @@
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
 #include "iree/compiler/Dialect/Util/IR/UtilOps.h"
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/AsmState.h"
@@ -253,6 +253,7 @@
 };
 
 static ParameterSlice getParameterSlice(Location loc, Attribute value,
+                                        IntegerSet<int64_t> &i64Set,
                                         IndexSet &indexSet,
                                         OpBuilder &builder) {
   auto parameterAttr = cast<IREE::Stream::NamedParameterAttr>(value);
@@ -260,17 +261,18 @@
   Value sourceLength;
   if (auto configAttr = parameterAttr.getConfig()) {
     if (auto offsetAttr = configAttr.getAs<IntegerAttr>("offset")) {
-      sourceOffset =
-          builder.create<arith::ConstantIntOp>(loc, offsetAttr.getInt(), 64);
+      sourceOffset = i64Set.get(offsetAttr.getValue());
     }
     if (auto lengthAttr = configAttr.getAs<IntegerAttr>("length")) {
       sourceLength = indexSet.get(lengthAttr.getInt());
     }
   }
-  if (!sourceOffset)
-    sourceOffset = builder.create<arith::ConstantIntOp>(loc, 0, 64);
-  if (!sourceLength)
+  if (!sourceOffset) {
+    sourceOffset = i64Set.get(0);
+  }
+  if (!sourceLength) {
     sourceLength = indexSet.get(parameterAttr.getStorageSize());
+  }
   return ParameterSlice{parameterAttr, sourceOffset, sourceLength};
 }
 
@@ -278,7 +280,8 @@
                                 IREE::Stream::AffinityAttr affinityAttr,
                                 Type targetType, StringAttr scope,
                                 ArrayRef<StorageResource *> storageResources,
-                                IndexSet &indexSet, OpBuilder &builder) {
+                                IntegerSet<int64_t> &i64Set, IndexSet &indexSet,
+                                OpBuilder &builder) {
   SmallVector<Location> spanLocs;
   SmallVector<Attribute> sourceKeys;
   SmallVector<Value> sourceOffsets;
@@ -289,8 +292,8 @@
            "expected single span per resource for load");
     for (auto &packedSpan : storageResource->spans) {
       auto spanLoc = packedSpan.slice.result.getLoc();
-      auto parameterSlice =
-          getParameterSlice(spanLoc, packedSpan.slice.value, indexSet, builder);
+      auto parameterSlice = getParameterSlice(spanLoc, packedSpan.slice.value,
+                                              i64Set, indexSet, builder);
       spanLocs.push_back(spanLoc);
       sourceKeys.push_back(parameterSlice.parameterAttr.getKey());
       sourceOffsets.push_back(parameterSlice.sourceOffset);
@@ -324,11 +327,10 @@
   return loadOp.getResultTimepoint();
 }
 
-static TimepointResource
-buildParameterGather(Location loc, Value awaitTimepoint,
-                     IREE::Stream::AffinityAttr affinityAttr, Type targetType,
-                     Value targetSize, MutableArrayRef<PackedSpan> packedSpans,
-                     IndexSet &indexSet, OpBuilder &builder) {
+static TimepointResource buildParameterGather(
+    Location loc, Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr,
+    Type targetType, Value targetSize, MutableArrayRef<PackedSpan> packedSpans,
+    IntegerSet<int64_t> &i64Set, IndexSet &indexSet, OpBuilder &builder) {
   // Allocate the resulting storage resource of the final resource type.
   auto allocOp = builder.create<IREE::Stream::ResourceAllocOp>(
       loc, targetType, targetSize,
@@ -352,8 +354,8 @@
     SmallVector<Value> targetLengths;
     sourceKeys.reserve(packedSpans.size());
     for (auto &packedSpan : packedSpans) {
-      auto parameterSlice =
-          getParameterSlice(loc, packedSpan.slice.value, indexSet, builder);
+      auto parameterSlice = getParameterSlice(loc, packedSpan.slice.value,
+                                              i64Set, indexSet, builder);
       sourceKeys.push_back(parameterSlice.parameterAttr.getKey());
       sourceOffsets.push_back(parameterSlice.sourceOffset);
       targetOffsets.push_back(indexSet.get(packedSpan.offset));
@@ -384,13 +386,11 @@
                            allocOp.getResultSize(0)};
 }
 
-static TimepointResource buildFileRead(Location loc, Value awaitTimepoint,
-                                       IREE::Stream::AffinityAttr affinityAttr,
-                                       IREE::Stream::ResourceType resourceType,
-                                       Value storageResourceSize,
-                                       Value storageBuffer,
-                                       Value storageBufferSize,
-                                       IndexSet &indexSet, OpBuilder &builder) {
+static TimepointResource buildFileRead(
+    Location loc, Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr,
+    IREE::Stream::ResourceType resourceType, Value storageResourceSize,
+    Value storageBuffer, Value storageBufferSize, IntegerSet<int64_t> &i64Set,
+    IndexSet &indexSet, OpBuilder &builder) {
   // Allocate the resulting storage resource of the final resource type.
   auto allocOp = builder.create<IREE::Stream::ResourceAllocOp>(
       loc, resourceType, storageResourceSize,
@@ -402,7 +402,7 @@
       storageResourceSize, affinityAttr);
 
   // Issue asynchronous file read into the buffer.
-  auto zeroI64 = builder.create<arith::ConstantIntOp>(loc, 0, 64);
+  auto zeroI64 = i64Set.get(0);
   auto readOp = builder.create<IREE::Stream::FileReadOp>(
       loc, fileOp.getResult(), zeroI64, allocOp.getResult(),
       allocOp.getResultSize(0), indexSet.get(0), storageResourceSize,
@@ -419,8 +419,8 @@
 static TimepointResource buildTryMapConstantResource(
     Location loc, Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr,
     IREE::Stream::ResourceType resourceType, Value storageResourceSize,
-    Value storageBuffer, Value storageBufferSize, IndexSet &indexSet,
-    OpBuilder &builder) {
+    Value storageBuffer, Value storageBufferSize, IntegerSet<int64_t> &i64Set,
+    IndexSet &indexSet, OpBuilder &builder) {
   // Try mapping; this may fail if the device can't use the storage buffer as
   // the type of resource requested.
   auto tryMapOp = builder.create<IREE::Stream::ResourceTryMapOp>(
@@ -443,7 +443,7 @@
         auto readResult =
             buildFileRead(loc, awaitTimepoint, affinityAttr, resourceType,
                           storageResourceSize, storageBuffer, storageBufferSize,
-                          indexSet, elseBuilder);
+                          i64Set, indexSet, elseBuilder);
         elseBuilder.create<scf::YieldOp>(loc, ValueRange{
                                                   readResult.timepoint,
                                                   readResult.resource,
@@ -457,7 +457,8 @@
 static Value generateSerializedUpload(
     Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr,
     IREE::Stream::ResourceConfigAttr resourceConfig,
-    ArrayRef<ConstantSlice> slices, IndexSet &indexSet, OpBuilder &builder) {
+    ArrayRef<ConstantSlice> slices, IntegerSet<int64_t> &i64Set,
+    IndexSet &indexSet, OpBuilder &builder) {
   // Perform the packing of dense values to compute the storage resources we
   // will need and where each value will be placed.
   auto storageResources =
@@ -491,11 +492,11 @@
     if (resourceType.getLifetime() == IREE::Stream::Lifetime::Constant) {
       uploadedResource = buildTryMapConstantResource(
           storageResource.loc, currentTimepoint, affinityAttr, resourceType,
-          resourceSize, storageBuffer, resourceSize, indexSet, builder);
+          resourceSize, storageBuffer, resourceSize, i64Set, indexSet, builder);
     } else {
       uploadedResource = buildFileRead(
           storageResource.loc, currentTimepoint, affinityAttr, resourceType,
-          resourceSize, storageBuffer, resourceSize, indexSet, builder);
+          resourceSize, storageBuffer, resourceSize, i64Set, indexSet, builder);
     }
 
     for (auto &span : storageResource.spans) {
@@ -516,7 +517,8 @@
 static Value generateParameterUpload(
     Value awaitTimepoint, IREE::Stream::AffinityAttr affinityAttr,
     IREE::Stream::ResourceConfigAttr resourceConfig,
-    ArrayRef<ConstantSlice> slices, IndexSet &indexSet, OpBuilder &builder) {
+    ArrayRef<ConstantSlice> slices, IntegerSet<int64_t> &i64Set,
+    IndexSet &indexSet, OpBuilder &builder) {
   auto anyResult = slices.front().result;
   auto resourceType =
       llvm::cast<IREE::Stream::ResourceType>(anyResult.getType());
@@ -574,7 +576,7 @@
   for (auto &[scope, scopeResources] : resourceLoads) {
     uploadTimepoints.push_back(
         buildParameterLoad(awaitTimepoint, affinityAttr, resourceType, scope,
-                           scopeResources, indexSet, builder));
+                           scopeResources, i64Set, indexSet, builder));
   }
 
   // Emit gathers, of which there may be multiple batches based on the target
@@ -583,7 +585,7 @@
     auto resourceSize = indexSet.get(storageResource->totalSize);
     auto uploadedResource = buildParameterGather(
         storageResource->loc, awaitTimepoint, affinityAttr, resourceType,
-        resourceSize, storageResource->spans, indexSet, builder);
+        resourceSize, storageResource->spans, i64Set, indexSet, builder);
     uploadTimepoints.push_back(uploadedResource.timepoint);
   }
 
@@ -594,7 +596,8 @@
 static Value generateUploads(Value awaitTimepoint,
                              IREE::Stream::ResourceConstantsOp constantsOp,
                              IREE::Stream::ResourceConfigAttr resourceConfig,
-                             IndexSet &indexSet, OpBuilder &builder) {
+                             IntegerSet<int64_t> &i64Set, IndexSet &indexSet,
+                             OpBuilder &builder) {
   // Split the slices based on whether they are sourced from serialized data or
   // externally-defined parameters.
   // TODO(benvanik): remove stream.resource.constants and this coupling;
@@ -622,12 +625,12 @@
   if (!serializedSlices.empty()) {
     uploadTimepoints.push_back(generateSerializedUpload(
         awaitTimepoint, constantsOp.getAffinityAttr(), resourceConfig,
-        serializedSlices, indexSet, builder));
+        serializedSlices, i64Set, indexSet, builder));
   }
   if (!parameterSlices.empty()) {
     uploadTimepoints.push_back(generateParameterUpload(
         awaitTimepoint, constantsOp.getAffinityAttr(), resourceConfig,
-        parameterSlices, indexSet, builder));
+        parameterSlices, i64Set, indexSet, builder));
   }
   return IREE::Stream::TimepointJoinOp::join(uploadTimepoints, builder);
 }
@@ -664,14 +667,16 @@
       // statically-known - CSE would collapse them but we use an IndexSet to
       // reduce the IR churn.
       OpBuilder builder(constantsOp);
+      IntegerSet<int64_t> i64Set(constantsOp.getLoc(), builder);
       IndexSet indexSet(constantsOp.getLoc(), builder);
       indexSet.populate(constantsOp.getResultSizes());
 
       // Perform upload/processing for immutable and mutable constants.
       Value awaitTimepoint = builder.create<IREE::Stream::TimepointImmediateOp>(
           constantsOp.getLoc());
-      auto uploadTimepoint = generateUploads(awaitTimepoint, constantsOp,
-                                             resourceConfig, indexSet, builder);
+      auto uploadTimepoint =
+          generateUploads(awaitTimepoint, constantsOp, resourceConfig, i64Set,
+                          indexSet, builder);
       constantsOp.getResultTimepoint().replaceAllUsesWith(uploadTimepoint);
 
       constantsOp.erase();
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp
index a741d36..36371f8 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/SpecializeDispatches.cpp
@@ -9,7 +9,7 @@
 
 #include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
 #include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp
index 3c219bd..c5f5344 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/PropagateSubranges.cpp
@@ -11,7 +11,7 @@
 #include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
 #include "iree/compiler/Dialect/Util/Transforms/Passes.h"
 #include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "llvm/ADT/BreadthFirstIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp
index 97ad935..cbec9cb 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.cpp
@@ -13,7 +13,7 @@
 #include "iree/compiler/Dialect/VMVX/IR/VMVXDialect.h"
 #include "iree/compiler/Dialect/VMVX/IR/VMVXOps.h"
 #include "iree/compiler/Dialect/VMVX/IR/VMVXTypes.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Attributes.h"
diff --git a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
index 45b2395..27c81b8 100644
--- a/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
+++ b/compiler/src/iree/compiler/Dialect/VMVX/Transforms/ResolveBufferDescriptors.cpp
@@ -8,7 +8,7 @@
 #include "iree/compiler/Dialect/VMVX/IR/VMVXDialect.h"
 #include "iree/compiler/Dialect/VMVX/Transforms/PassDetail.h"
 #include "iree/compiler/Dialect/VMVX/Transforms/Passes.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
diff --git a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp
index d652ace..5120999 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/ExpandTensorShapes.cpp
@@ -13,7 +13,7 @@
 #include "iree/compiler/Dialect/Util/Transforms/Patterns.h"
 #include "iree/compiler/GlobalOptimization/PassDetail.h"
 #include "iree/compiler/GlobalOptimization/Passes.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "llvm/ADT/BreadthFirstIterator.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp
index 89edd28..8270941 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/InlineExecutables.cpp
@@ -11,7 +11,7 @@
 #include "iree/compiler/Modules/HAL/Inline/IR/HALInlineDialect.h"
 #include "iree/compiler/Modules/HAL/Inline/Transforms/PassDetail.h"
 #include "iree/compiler/Modules/HAL/Inline/Transforms/Passes.h"
-#include "iree/compiler/Utils/IndexSet.h"
+#include "iree/compiler/Utils/IntegerSet.h"
 #include "iree/compiler/Utils/ModuleUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
diff --git a/compiler/src/iree/compiler/Utils/BUILD.bazel b/compiler/src/iree/compiler/Utils/BUILD.bazel
index 170bd9a..dbcdc01 100644
--- a/compiler/src/iree/compiler/Utils/BUILD.bazel
+++ b/compiler/src/iree/compiler/Utils/BUILD.bazel
@@ -34,8 +34,8 @@
         "EquivalenceUtils.h",
         "FlatbufferUtils.h",
         "Folding.h",
-        "IndexSet.h",
         "Indexing.h",
+        "IntegerSet.h",
         "ModuleUtils.h",
         "OpVisitor.h",
         "OptionUtils.h",
diff --git a/compiler/src/iree/compiler/Utils/CMakeLists.txt b/compiler/src/iree/compiler/Utils/CMakeLists.txt
index 114c9a5..c4f20b2 100644
--- a/compiler/src/iree/compiler/Utils/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Utils/CMakeLists.txt
@@ -19,8 +19,8 @@
     "EquivalenceUtils.h"
     "FlatbufferUtils.h"
     "Folding.h"
-    "IndexSet.h"
     "Indexing.h"
+    "IntegerSet.h"
     "ModuleUtils.h"
     "OpVisitor.h"
     "OptionUtils.h"
diff --git a/compiler/src/iree/compiler/Utils/IndexSet.h b/compiler/src/iree/compiler/Utils/IndexSet.h
deleted file mode 100644
index 055b0ca..0000000
--- a/compiler/src/iree/compiler/Utils/IndexSet.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_UTILS_INDEXSET_H_
-#define IREE_COMPILER_UTILS_INDEXSET_H_
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Matchers.h"
-
-namespace mlir::iree_compiler {
-
-// Simple cache for generated index values.
-// Always inserts at the location specified by the builder when constructed.
-class IndexSet {
-public:
-  explicit IndexSet(Location loc, OpBuilder builder)
-      : loc(loc), builder(builder) {}
-
-  Value get(int64_t value) {
-    auto it = memoizedIndices.find(value);
-    if (it != memoizedIndices.end())
-      return it->second;
-    auto memoizedValue =
-        builder.create<arith::ConstantIndexOp>(loc, value).getResult();
-    memoizedIndices[value] = memoizedValue;
-    return memoizedValue;
-  }
-  Value get(APInt value) { return get(value.getSExtValue()); }
-
-  void populate(ValueRange values) {
-    for (auto value : values) {
-      APInt intValue;
-      if (matchPattern(value, m_ConstantInt(&intValue))) {
-        memoizedIndices.insert(std::make_pair(intValue.getSExtValue(), value));
-      }
-    }
-  }
-
-private:
-  Location loc;
-  OpBuilder builder;
-  DenseMap<int64_t, Value> memoizedIndices;
-};
-
-} // namespace mlir::iree_compiler
-
-#endif // IREE_COMPILER_UTILS_INDEXSET_H_
diff --git a/compiler/src/iree/compiler/Utils/IntegerSet.h b/compiler/src/iree/compiler/Utils/IntegerSet.h
new file mode 100644
index 0000000..594eecd
--- /dev/null
+++ b/compiler/src/iree/compiler/Utils/IntegerSet.h
@@ -0,0 +1,86 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_UTILS_INTEGERSET_H_
+#define IREE_COMPILER_UTILS_INTEGERSET_H_
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
+
+namespace mlir::iree_compiler {
+
+// Simple cache for generated integer values with the specified storage type.
+// Always inserts at the location specified by the builder when constructed.
+template <typename StorageT>
+class IntegerSet {
+public:
+  explicit IntegerSet(Location loc, OpBuilder builder)
+      : loc(loc), builder(builder) {}
+
+  Value get(StorageT value) { return get(APInt(sizeof(StorageT) * 8, value)); }
+  Value get(APInt value) {
+    auto it = memoizedValues.find(value);
+    if (it != memoizedValues.end()) {
+      return it->second;
+    }
+    Value memoizedValue = builder.create<arith::ConstantIntOp>(
+        loc, *value.getRawData(), value.getBitWidth());
+    memoizedValues[value] = memoizedValue;
+    return memoizedValue;
+  }
+
+  void populate(ValueRange values) {
+    for (auto value : values) {
+      APInt intValue;
+      if (matchPattern(value, m_ConstantInt(&intValue))) {
+        memoizedValues.insert(std::make_pair(intValue, value));
+      }
+    }
+  }
+
+private:
+  Location loc;
+  OpBuilder builder;
+  DenseMap<APInt, Value> memoizedValues;
+};
+
+// Simple cache for generated index values.
+// Always inserts at the location specified by the builder when constructed.
+class IndexSet {
+public:
+  explicit IndexSet(Location loc, OpBuilder builder)
+      : loc(loc), builder(builder) {}
+
+  Value get(int64_t value) {
+    auto it = memoizedIndices.find(value);
+    if (it != memoizedIndices.end()) {
+      return it->second;
+    }
+    Value memoizedValue = builder.create<arith::ConstantIndexOp>(loc, value);
+    memoizedIndices[value] = memoizedValue;
+    return memoizedValue;
+  }
+  Value get(APInt value) { return get(value.getSExtValue()); }
+
+  void populate(ValueRange values) {
+    for (auto value : values) {
+      APInt intValue;
+      if (matchPattern(value, m_ConstantInt(&intValue))) {
+        memoizedIndices.insert(std::make_pair(intValue.getSExtValue(), value));
+      }
+    }
+  }
+
+private:
+  Location loc;
+  OpBuilder builder;
+  DenseMap<int64_t, Value> memoizedIndices;
+};
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_UTILS_INTEGERSET_H_