Add simple pass to turn dense attributes into dense_resource attributes. (#14574)

diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
index c0d0b57..524113a 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
@@ -25,6 +25,7 @@
         "FuseGlobals.cpp",
         "HoistIntoGlobals.cpp",
         "IPO.cpp",
+        "ImportResources.cpp",
         "PassDetail.h",
         "Passes.cpp",
         "Patterns.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
index 284af5f..8ae9513 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
@@ -28,6 +28,7 @@
     "FuseGlobals.cpp"
     "HoistIntoGlobals.cpp"
     "IPO.cpp"
+    "ImportResources.cpp"
     "PassDetail.h"
     "Passes.cpp"
     "Patterns.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp
new file mode 100644
index 0000000..625f680
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp
@@ -0,0 +1,206 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#define DEBUG_TYPE "iree-util-import-resources"
+
+namespace mlir::iree_compiler::IREE::Util {
+
+namespace {
+
+// Shim exposing DenseResourceElementsAttr::get through a using-declaration.
+// TODO: Use that builder directly once https://reviews.llvm.org/D157064 lands.
+class DenseBlobResourceElementsAttr : public DenseResourceElementsAttr {
+public:
+  using DenseResourceElementsAttr::get;
+};
+
+// Writes the elements of `attr` into `blob`, viewed as an ElementType array.
+template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8>
+static void copyIntAttrIntoBlob(AsmResourceBlob &blob,
+                                DenseIntElementsAttr attr) {
+  ArrayRef<ElementType> view = blob.getDataAs<ElementType>();
+  MutableArrayRef<ElementType> dst(const_cast<ElementType *>(view.data()),
+                                   view.size());
+  ArrayRef<char> src = attr.getRawData();
+  if (src.size() != blob.getData().size()) {
+    // Sizes differ: store each element's low `numBits` bits, zero-extended.
+    size_t slot = 0;
+    for (APInt value : attr.getValues<APInt>())
+      dst[slot++] = value.extractBitsAsZExtValue(numBits, 0);
+  } else {
+    // Same byte size: copy the attribute's raw storage verbatim.
+    std::memcpy(dst.data(), src.data(), src.size());
+  }
+}
+
+// Writes the elements of `attr` into `blob`, viewed as an ElementType array.
+// Values are transferred as raw bits (memcpy or bitcast), never converted.
+template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8>
+static void copyFPAttrIntoBlob(AsmResourceBlob &blob,
+                               DenseFPElementsAttr attr) {
+  ArrayRef<ElementType> view = blob.getDataAs<ElementType>();
+  MutableArrayRef<ElementType> dst(const_cast<ElementType *>(view.data()),
+                                   view.size());
+  ArrayRef<char> src = attr.getRawData();
+  if (src.size() != blob.getData().size()) {
+    // Sizes differ: store the low `numBits` bits of each value's bit pattern.
+    size_t slot = 0;
+    for (APFloat value : attr.getValues<APFloat>())
+      dst[slot++] = value.bitcastToAPInt().extractBitsAsZExtValue(numBits, 0);
+  } else {
+    // Same byte size: copy the attribute's raw storage verbatim.
+    std::memcpy(dst.data(), src.data(), src.size());
+  }
+}
+
+// Converts non-splat dense integer and floating-point elements attributes on
+// the visited operations into dense_resource attributes backed by heap blobs.
+class ImportResourcesPass : public ImportResourcesBase<ImportResourcesPass> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<BuiltinDialect>();
+  }
+
+  void runOnOperation() override {
+    // Caches the replacement of every attribute converted so far so that an
+    // attribute referenced from several places is only converted once.
+    llvm::DenseMap<Attribute, Attribute> replacements;
+
+    getOperation()->walk([&](Operation *op) {
+      bool updated = false;
+      // Edit a copy of the attribute list; it is written back only on change.
+      SmallVector<NamedAttribute> attrs(op->getAttrs());
+      for (auto &attr : attrs) {
+        auto elements = llvm::dyn_cast<ElementsAttr>(attr.getValue());
+        if (!elements)
+          continue;
+
+        // Already seen?
+        auto it = replacements.find(elements);
+        if (it != replacements.end()) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << ":: Replacing already encountered attr of "
+                     << elements.getType() << "\n");
+          attr.setValue(it->second);
+          updated = true;
+          continue;
+        }
+
+        // Convert.
+        if (!shouldConvertElements(elements))
+          continue;
+        LLVM_DEBUG(llvm::dbgs() << ":: Converting elements attr of "
+                                << elements.getType() << "\n");
+        if (auto replacement = convertElementsAttr(elements)) {
+          attr.setValue(replacement);
+          replacements[elements] = replacement;
+          updated = true;
+        } else {
+          LLVM_DEBUG(llvm::dbgs() << "  Failed to convert\n");
+        }
+      }
+      if (updated)
+        op->setAttrs(attrs);
+    });
+    LLVM_DEBUG(llvm::dbgs() << "DONE CONVERTING RESOURCES\n");
+  }
+
+  // Returns true if `attr` is a non-splat DenseElementsAttr.
+  static bool shouldConvertElements(ElementsAttr attr) {
+    if (!llvm::isa<DenseElementsAttr>(attr))
+      return false;
+    // DenseElementsAttr encodes arbitrary dimension
+    // splats whereas DenseResourceElementsAttr does not.
+    return !attr.isSplat();
+  }
+
+  // Builds the dense_resource equivalent of `elementsAttr`. Returns a null
+  // attribute unless the element type is a 1/8/16/32/64-bit integer or an
+  // 8/16/32/64-bit float.
+  static ElementsAttr convertElementsAttr(ElementsAttr elementsAttr) {
+    auto st = llvm::cast<ShapedType>(elementsAttr.getType());
+    auto elementType = st.getElementType();
+    // Only integer and float types have a bit width; querying it for any other
+    // element type (e.g. index or complex) trips an assertion in MLIR.
+    if (!elementType.isIntOrFloat())
+      return {};
+    auto bitWidth = elementType.getIntOrFloatBitWidth();
+    if (auto attr = llvm::dyn_cast<DenseIntElementsAttr>(elementsAttr)) {
+      switch (bitWidth) {
+      case 1:
+        return importInt<uint8_t, /*numBits=*/1>(st, "dense_elements_i1", attr);
+      case 8:
+        return importInt<uint8_t>(st, "dense_elements_i8", attr);
+      case 16:
+        return importInt<uint16_t>(st, "dense_elements_i16", attr);
+      case 32:
+        return importInt<uint32_t>(st, "dense_elements_i32", attr);
+      case 64:
+        return importInt<uint64_t>(st, "dense_elements_i64", attr);
+      default:
+        return {};
+      }
+    }
+    if (auto attr = llvm::dyn_cast<DenseFPElementsAttr>(elementsAttr)) {
+      switch (bitWidth) {
+      case 8:
+        return importFP<uint8_t>(st, "dense_elements_f8", attr);
+      case 16:
+        return importFP<uint16_t>(st, "dense_elements_f16", attr);
+      case 32:
+        return importFP<uint32_t>(st, "dense_elements_f32", attr);
+      case 64:
+        return importFP<uint64_t>(st, "dense_elements_f64", attr);
+      default:
+        return {};
+      }
+    }
+    return {};
+  }
+
+private:
+  // Copies `attr` into a new mutable heap blob holding one ElementType slot
+  // per element and wraps it in a dense_resource attribute of type `st`,
+  // passing `name` as the blob name.
+  template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8>
+  static ElementsAttr importInt(ShapedType st, StringRef name,
+                                DenseIntElementsAttr attr) {
+    AsmResourceBlob blob = HeapAsmResourceBlob::allocate(
+        sizeof(ElementType) * attr.getNumElements(), /*align=*/64,
+        /*dataIsMutable=*/true);
+    copyIntAttrIntoBlob<ElementType, numBits>(blob, attr);
+    return DenseBlobResourceElementsAttr::get(st, name, std::move(blob));
+  }
+
+  // Floating-point counterpart of importInt.
+  template <typename ElementType>
+  static ElementsAttr importFP(ShapedType st, StringRef name,
+                               DenseFPElementsAttr attr) {
+    AsmResourceBlob blob = HeapAsmResourceBlob::allocate(
+        sizeof(ElementType) * attr.getNumElements(), /*align=*/64,
+        /*dataIsMutable=*/true);
+    copyFPAttrIntoBlob<ElementType>(blob, attr);
+    return DenseBlobResourceElementsAttr::get(st, name, std::move(blob));
+  }
+};
+
+} // namespace
+
+// Creates a new ImportResourcesPass instance.
+std::unique_ptr<OperationPass<void>> createImportResourcesPass() {
+  return std::make_unique<ImportResourcesPass>();
+}
+} // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
index 39c219f..abf1c19 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
@@ -29,6 +29,9 @@
 std::unique_ptr<OperationPass<void>> createSimplifyGlobalAccessesPass();
 std::unique_ptr<OperationPass<void>> createStripDebugOpsPass();
 
+// Resource Management.
+std::unique_ptr<OperationPass<void>> createImportResourcesPass();
+
 // Type conversion.
 std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteI64ToI32Pass();
 std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF32ToF16Pass();
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
index 0da39c0..873b1ab 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
@@ -102,6 +102,27 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Resource Management
+//===----------------------------------------------------------------------===//
+
+def ImportResources : Pass<"iree-util-import-resources", ""> {
+  let summary = "Imports IR with arbitrary large-data into resources that IREE can manage efficiently";
+  let description = [{
+    MLIR has many interesting ways to store large constants, most of which
+    derive from *ElementsAttr. Given the uniquing/inline behavior, this exacts
+    very large runtime and memory overhead costs.
+
+    This is a temporary pass to convert a majority of the legacy
+    DenseElementsAttr attributes to DenseResourceElementsAttr. Ideally this
+    is done at the source (frontend), but this pass is provided to aid
+    transition and testing by doing a manual conversion with iree-opt.
+  }];
+  let constructor = [{
+    mlir::iree_compiler::IREE::Util::createImportResourcesPass()
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
index db4655d..c473c64 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
@@ -27,6 +27,7 @@
             "fuse_globals.mlir",
             "hoist_into_globals.mlir",
             "hoist_into_globals_linalg.mlir",
+            "import_resources.mlir",
             "ipo.mlir",
             "promote_bf16_to_f32.mlir",
             "promote_f16_to_f32.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
index f855c45..dfc917f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
@@ -25,6 +25,7 @@
     "fuse_globals.mlir"
     "hoist_into_globals.mlir"
     "hoist_into_globals_linalg.mlir"
+    "import_resources.mlir"
     "ipo.mlir"
     "promote_bf16_to_f32.mlir"
     "promote_f16_to_f32.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir
new file mode 100644
index 0000000..e8b3b50
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir
@@ -0,0 +1,89 @@
+// RUN: iree-opt --split-input-file --iree-util-import-resources %s | FileCheck %s
+
+// CHECK-LABEL: func.func @constant_splat_i64
+func.func @constant_splat_i64() -> tensor<4xi64> {
+  // Splats should not convert.
+  // CHECK-NEXT: constant dense<123>
+  %c123 = arith.constant dense<123> : tensor<4xi64>
+  return %c123 : tensor<4xi64>
+}
+
+// -----
+// CHECK-LABEL: func.func @dense_i1
+func.func @dense_i1() -> tensor<4xi1> {
+  // CHECK: dense_resource<dense_elements_i1>
+  %c123 = arith.constant dense<[true, false, false, true]> : tensor<4xi1>
+  return %c123 : tensor<4xi1>
+}
+
+// CHECK: dense_elements_i1: "0x4000000001000001"
+
+// -----
+// CHECK-LABEL: func.func @dense_i8
+func.func @dense_i8() -> tensor<4xi8> {
+  // CHECK: dense_resource<dense_elements_i8>
+  %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi8>
+  return %c123 : tensor<4xi8>
+}
+
+// CHECK: dense_elements_i8: "0x400000000102037F"
+
+// -----
+// CHECK-LABEL: func.func @dense_i16
+func.func @dense_i16() -> tensor<4xi16> {
+  // CHECK: dense_resource<dense_elements_i16>
+  %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi16>
+  return %c123 : tensor<4xi16>
+}
+
+// CHECK: dense_elements_i16: "0x400000000100020003007F00"
+
+// -----
+// CHECK-LABEL: func.func @dense_i32
+func.func @dense_i32() -> tensor<4xi32> {
+  // CHECK: dense_resource<dense_elements_i32>
+  %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi32>
+  return %c123 : tensor<4xi32>
+}
+
+// CHECK: dense_elements_i32: "0x400000000100000002000000030000007F000000"
+
+// -----
+// CHECK-LABEL: func.func @dense_i64
+func.func @dense_i64() -> tensor<4xi64> {
+  // CHECK: dense_resource<dense_elements_i64>
+  %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi64>
+  return %c123 : tensor<4xi64>
+}
+
+// CHECK: dense_elements_i64: "0x400000000100000000000000020000000000000003000000000000007F00000000000000"
+
+// -----
+// CHECK-LABEL: func.func @dense_f16
+func.func @dense_f16() -> tensor<4xf16> {
+  // CHECK: dense_resource<dense_elements_f16>
+  %c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf16>
+  return %c123 : tensor<4xf16>
+}
+
+// CHECK: dense_elements_f16: "0x40000000663C66409A420000"
+
+// -----
+// CHECK-LABEL: func.func @dense_f32
+func.func @dense_f32() -> tensor<4xf32> {
+  // CHECK: dense_resource<dense_elements_f32>
+  %c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf32>
+  return %c123 : tensor<4xf32>
+}
+
+// CHECK: dense_elements_f32: "0x40000000CDCC8C3FCDCC0C403333534000000000"
+
+// -----
+// CHECK-LABEL: func.func @dense_f64
+func.func @dense_f64() -> tensor<4xf64> {
+  // CHECK: dense_resource<dense_elements_f64>
+  %c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf64>
+  return %c123 : tensor<4xf64>
+}
+
+// CHECK: dense_elements_f64: "0x400000009A9999999999F13F9A999999999901406666666666660A400000000000000000"