Add simple pass to turn dense attributes into dense_resource attributes. (#14574)
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
index c0d0b57..524113a 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
@@ -25,6 +25,7 @@
"FuseGlobals.cpp",
"HoistIntoGlobals.cpp",
"IPO.cpp",
+ "ImportResources.cpp",
"PassDetail.h",
"Passes.cpp",
"Patterns.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
index 284af5f..8ae9513 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
@@ -28,6 +28,7 @@
"FuseGlobals.cpp"
"HoistIntoGlobals.cpp"
"IPO.cpp"
+ "ImportResources.cpp"
"PassDetail.h"
"Passes.cpp"
"Patterns.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp
new file mode 100644
index 0000000..625f680
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/ImportResources.cpp
@@ -0,0 +1,206 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <cstring>
+#include <utility>
+
+#include "iree/compiler/Dialect/Util/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#define DEBUG_TYPE "iree-util-import-resources"
+
+namespace mlir::iree_compiler::IREE::Util {
+
+namespace {
+
+// Thin subclass of DenseResourceElementsAttr whose only member is a public
+// using-declaration for the base class `get` builders, so that code in this
+// file can call `DenseBlobResourceElementsAttr::get(...)`.
+// NOTE(review): presumably the overload taking a blob is not publicly
+// callable on the base class yet - confirm against the review linked below.
+// TODO: Just use the DenseResourceElementsAttr::get()
+// builder once https://reviews.llvm.org/D157064 lands.
+class DenseBlobResourceElementsAttr : public DenseResourceElementsAttr {
+public:
+ using DenseResourceElementsAttr::get;
+};
+
+/// Fills `blob` with the integer values held by `attr`.
+///
+/// `ElementType` is the unsigned storage type written into the blob and
+/// `numBits` is the number of low bits taken from each source value when the
+/// element-wise path is used (defaults to the full width of `ElementType`).
+/// The blob must already be sized by the caller and must be writable: its
+/// storage is written through a const_cast of the read-only view.
+template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8>
+static void copyIntAttrIntoBlob(AsmResourceBlob &blob,
+                                DenseIntElementsAttr attr) {
+  ArrayRef<ElementType> typedView = blob.getDataAs<ElementType>();
+  MutableArrayRef<ElementType> dest(
+      const_cast<ElementType *>(typedView.data()), typedView.size());
+  ArrayRef<char> srcBytes = attr.getRawData();
+
+  // Fast path: the attribute's raw storage has exactly the blob's byte size,
+  // so it is copied verbatim.
+  if (srcBytes.size() == blob.getData().size()) {
+    std::memcpy(dest.data(), srcBytes.data(), srcBytes.size());
+    return;
+  }
+
+  // Slow path: byte sizes differ, so every value is written individually,
+  // keeping its low `numBits` bits zero-extended to `ElementType`.
+  size_t outIndex = 0;
+  for (APInt element : attr.getValues<APInt>()) {
+    dest[outIndex] = element.extractBitsAsZExtValue(numBits, 0);
+    ++outIndex;
+  }
+}
+
+/// Fills `blob` with the bit patterns of the floating-point values held by
+/// `attr`.
+///
+/// `ElementType` is the unsigned storage type written into the blob and
+/// `numBits` is the number of low bits taken from each value's bit pattern
+/// when the element-wise path is used (defaults to the full width of
+/// `ElementType`). The blob must already be sized by the caller and must be
+/// writable: its storage is written through a const_cast of the read-only
+/// view.
+template <typename ElementType, unsigned numBits = sizeof(ElementType) * 8>
+static void copyFPAttrIntoBlob(AsmResourceBlob &blob,
+                               DenseFPElementsAttr attr) {
+  ArrayRef<ElementType> typedView = blob.getDataAs<ElementType>();
+  MutableArrayRef<ElementType> dest(
+      const_cast<ElementType *>(typedView.data()), typedView.size());
+  ArrayRef<char> srcBytes = attr.getRawData();
+
+  // Fast path: the attribute's raw storage has exactly the blob's byte size,
+  // so it is copied verbatim.
+  if (srcBytes.size() == blob.getData().size()) {
+    std::memcpy(dest.data(), srcBytes.data(), srcBytes.size());
+    return;
+  }
+
+  // Slow path: byte sizes differ, so each value is bitcast to an integer and
+  // its low `numBits` bits are stored one element at a time.
+  size_t outIndex = 0;
+  for (APFloat element : attr.getValues<APFloat>()) {
+    APInt bits = element.bitcastToAPInt();
+    dest[outIndex] = bits.extractBitsAsZExtValue(numBits, 0);
+    ++outIndex;
+  }
+}
+
+/// Pass that rewrites non-splat `DenseElementsAttr` attribute values found on
+/// the operations visited from the pass root into `DenseResourceElementsAttr`
+/// values backed by heap-allocated resource blobs.
+class ImportResourcesPass : public ImportResourcesBase<ImportResourcesPass> {
+public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<BuiltinDialect>();
+  }
+
+  /// Walks the operations under the pass root and replaces every convertible
+  /// elements attribute with its resource-backed equivalent.
+  void runOnOperation() override {
+    // Original attribute -> converted attribute. An attribute that shows up
+    // on several operations is converted once and then reused.
+    llvm::DenseMap<Attribute, Attribute> replacements;
+
+    getOperation()->walk([&](Operation *op) {
+      bool updated = false;
+      // Edit a copy of the attribute list; it is written back to the op only
+      // when at least one entry changed.
+      SmallVector<NamedAttribute> attrs(op->getAttrs());
+      for (auto &attr : attrs) {
+        auto elements = llvm::dyn_cast<ElementsAttr>(attr.getValue());
+        if (!elements)
+          continue;
+
+        // Already seen?
+        auto it = replacements.find(elements);
+        if (it != replacements.end()) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << ":: Replacing already encountered attr of "
+                     << elements.getType() << "\n");
+          attr.setValue(it->second);
+          updated = true;
+          continue;
+        }
+
+        // Convert.
+        if (!shouldConvertElements(elements))
+          continue;
+        LLVM_DEBUG(llvm::dbgs() << ":: Converting elements attr of "
+                                << elements.getType() << "\n");
+        auto replacement = convertElementsAttr(elements);
+        if (!replacement) {
+          // Unsupported element type or bit width: keep the original value.
+          LLVM_DEBUG(llvm::dbgs() << " Failed to convert\n");
+          continue;
+        }
+        attr.setValue(replacement);
+        replacements[elements] = replacement;
+        updated = true;
+      }
+      if (updated)
+        op->setAttrs(attrs);
+    });
+    LLVM_DEBUG(llvm::dbgs() << "DONE CONVERTING RESOURCES\n");
+  }
+
+  /// Returns true when `attr` is a `DenseElementsAttr` that is not a splat.
+  static bool shouldConvertElements(ElementsAttr attr) {
+    if (llvm::isa<DenseElementsAttr>(attr)) {
+      // DenseElementsAttr encodes arbitrary dimension
+      // splats whereas DenseResourceElementsAttr does not.
+      return !attr.isSplat();
+    }
+
+    return false;
+  }
+
+  /// Builds a resource-backed copy of `elementsAttr`.
+  ///
+  /// Supported inputs are dense integer attributes of width 1, 8, 16, 32 or
+  /// 64 bits and dense floating-point attributes of width 8, 16, 32 or 64
+  /// bits. A null attribute is returned for anything else, including element
+  /// types that are neither integer nor floating point.
+  static ElementsAttr convertElementsAttr(ElementsAttr elementsAttr) {
+    auto st = llvm::cast<ShapedType>(elementsAttr.getType());
+    auto elementType = st.getElementType();
+    // getIntOrFloatBitWidth() is only valid for integer and floating-point
+    // element types; bail out for everything else (e.g. index) instead of
+    // querying a bit width that does not exist.
+    if (!elementType.isIntOrFloat())
+      return {};
+    auto numElements = elementsAttr.getNumElements();
+    auto bitWidth = elementType.getIntOrFloatBitWidth();
+
+    // Allocates a writable, 64-byte aligned blob holding `bytesPerElement`
+    // bytes for each element of the source attribute.
+    auto allocateBlob = [&](int64_t bytesPerElement) {
+      return HeapAsmResourceBlob::allocate(bytesPerElement * numElements,
+                                           /*align=*/64,
+                                           /*dataIsMutable=*/true);
+    };
+
+    AsmResourceBlob blob;
+    if (auto attr = llvm::dyn_cast<DenseIntElementsAttr>(elementsAttr)) {
+      switch (bitWidth) {
+      case 1:
+        // i1 values get one byte each in the blob; only bit 0 of every source
+        // value is extracted.
+        blob = allocateBlob(1);
+        copyIntAttrIntoBlob<uint8_t, /*numBits=*/1>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_i1",
+                                                  std::move(blob));
+      case 8:
+        blob = allocateBlob(1);
+        copyIntAttrIntoBlob<uint8_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_i8",
+                                                  std::move(blob));
+      case 16:
+        blob = allocateBlob(2);
+        copyIntAttrIntoBlob<uint16_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_i16",
+                                                  std::move(blob));
+      case 32:
+        blob = allocateBlob(4);
+        copyIntAttrIntoBlob<uint32_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_i32",
+                                                  std::move(blob));
+      case 64:
+        blob = allocateBlob(8);
+        copyIntAttrIntoBlob<uint64_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_i64",
+                                                  std::move(blob));
+      default:
+        return {};
+      }
+    }
+    if (auto attr = llvm::dyn_cast<DenseFPElementsAttr>(elementsAttr)) {
+      switch (bitWidth) {
+      case 8:
+        blob = allocateBlob(1);
+        copyFPAttrIntoBlob<uint8_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_f8",
+                                                  std::move(blob));
+      case 16:
+        blob = allocateBlob(2);
+        copyFPAttrIntoBlob<uint16_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_f16",
+                                                  std::move(blob));
+      case 32:
+        blob = allocateBlob(4);
+        copyFPAttrIntoBlob<uint32_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_f32",
+                                                  std::move(blob));
+      case 64:
+        blob = allocateBlob(8);
+        copyFPAttrIntoBlob<uint64_t>(blob, attr);
+        return DenseBlobResourceElementsAttr::get(st, "dense_elements_f64",
+                                                  std::move(blob));
+      default:
+        return {};
+      }
+    }
+    return {};
+  }
+};
+
+} // namespace
+
+/// Factory for the import-resources pass; returns a newly allocated
+/// ImportResourcesPass owned by the caller.
+std::unique_ptr<OperationPass<void>> createImportResourcesPass() {
+  std::unique_ptr<OperationPass<void>> pass =
+      std::make_unique<ImportResourcesPass>();
+  return pass;
+}
+
+} // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
index 39c219f..abf1c19 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
@@ -29,6 +29,9 @@
std::unique_ptr<OperationPass<void>> createSimplifyGlobalAccessesPass();
std::unique_ptr<OperationPass<void>> createStripDebugOpsPass();
+// Resource Management.
+std::unique_ptr<OperationPass<void>> createImportResourcesPass();
+
// Type conversion.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteI64ToI32Pass();
std::unique_ptr<OperationPass<mlir::ModuleOp>> createDemoteF32ToF16Pass();
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
index 0da39c0..873b1ab 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
@@ -102,6 +102,27 @@
}
//===----------------------------------------------------------------------===//
+// Resource Management
+//===----------------------------------------------------------------------===//
+
+// Declares the `iree-util-import-resources` pass. The second `Pass` argument
+// (the op name the pass is anchored on) is left empty, and the constructor
+// named below is declared in Passes.h as returning `OperationPass<void>`.
+def ImportResources : Pass<"iree-util-import-resources", ""> {
+ let summary = "Imports IR with arbitrary large-data into resources that IREE can manage efficiently";
+ let description = [{
+ MLIR has many interesting ways to store large constants, most of which
+ derive from *ElementsAttr. Given the uniquing/inline behavior, this exacts
+ very large runtime and memory overhead costs.
+
+ This is a temporary pass to convert a majority of the legacy
+ DenseElementsAttr attributes to DenseResourceElementsAttr. Ideally this
+ is done at the source (frontend), but this pass is provided to aid
+ transition and testing by doing a manual conversion with iree-opt.
+ }];
+ // Factory function that creates the pass instance.
+ let constructor = [{
+ mlir::iree_compiler::IREE::Util::createImportResourcesPass()
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
index db4655d..c473c64 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
@@ -27,6 +27,7 @@
"fuse_globals.mlir",
"hoist_into_globals.mlir",
"hoist_into_globals_linalg.mlir",
+ "import_resources.mlir",
"ipo.mlir",
"promote_bf16_to_f32.mlir",
"promote_f16_to_f32.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
index f855c45..dfc917f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
@@ -25,6 +25,7 @@
"fuse_globals.mlir"
"hoist_into_globals.mlir"
"hoist_into_globals_linalg.mlir"
+ "import_resources.mlir"
"ipo.mlir"
"promote_bf16_to_f32.mlir"
"promote_f16_to_f32.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir
new file mode 100644
index 0000000..e8b3b50
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/import_resources.mlir
@@ -0,0 +1,89 @@
+// RUN: iree-opt --split-input-file --iree-util-import-resources %s | FileCheck %s
+
+// CHECK-LABEL: func.func @constant_splat_i64
+func.func @constant_splat_i64() -> tensor<4xi64> {
+ // Splats should not convert.
+ // The pass only converts non-splat dense attributes, so this constant is
+ // expected to be printed unchanged.
+ // CHECK-NEXT: constant dense<123>
+ %c123 = arith.constant dense<123> : tensor<4xi64>
+ return %c123 : tensor<4xi64>
+}
+
+// -----
+// CHECK-LABEL: func.func @dense_i1
+func.func @dense_i1() -> tensor<4xi1> {
+ // Non-splat i1 data is expected to be rewritten to a dense_resource. The
+ // payload expected after this function holds one byte per element
+ // (01 00 00 01 for true, false, false, true).
+ // NOTE(review): the leading 0x40000000 in the payload presumably encodes
+ // the 64-byte blob alignment - confirm against the MLIR resource printer.
+ // CHECK: dense_resource<dense_elements_i1>
+ %c123 = arith.constant dense<[true, false, false, true]> : tensor<4xi1>
+ return %c123 : tensor<4xi1>
+}
+
+// CHECK: dense_elements_i1: "0x4000000001000001"
+
+// -----
+// CHECK-LABEL: func.func @dense_i8
+func.func @dense_i8() -> tensor<4xi8> {
+ // Non-splat i8 data is expected to be rewritten to a dense_resource whose
+ // payload holds one byte per element (01 02 03 7F).
+ // CHECK: dense_resource<dense_elements_i8>
+ %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi8>
+ return %c123 : tensor<4xi8>
+}
+
+// CHECK: dense_elements_i8: "0x400000000102037F"
+
+// -----
+// CHECK-LABEL: func.func @dense_i16
+func.func @dense_i16() -> tensor<4xi16> {
+ // Non-splat i16 data is expected to be rewritten to a dense_resource whose
+ // payload holds two bytes per element, low byte first (0100 0200 0300 7F00).
+ // CHECK: dense_resource<dense_elements_i16>
+ %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi16>
+ return %c123 : tensor<4xi16>
+}
+
+// CHECK: dense_elements_i16: "0x400000000100020003007F00"
+
+// -----
+// CHECK-LABEL: func.func @dense_i32
+func.func @dense_i32() -> tensor<4xi32> {
+ // Non-splat i32 data is expected to be rewritten to a dense_resource whose
+ // payload holds four bytes per element, low byte first.
+ // CHECK: dense_resource<dense_elements_i32>
+ %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi32>
+ return %c123 : tensor<4xi32>
+}
+
+// CHECK: dense_elements_i32: "0x400000000100000002000000030000007F000000"
+
+// -----
+// CHECK-LABEL: func.func @dense_i64
+func.func @dense_i64() -> tensor<4xi64> {
+ // Non-splat i64 data is expected to be rewritten to a dense_resource whose
+ // payload holds eight bytes per element, low byte first. Contrast with the
+ // splat i64 constant in @constant_splat_i64, which is left untouched.
+ // CHECK: dense_resource<dense_elements_i64>
+ %c123 = arith.constant dense<[1, 2, 3, 127]> : tensor<4xi64>
+ return %c123 : tensor<4xi64>
+}
+
+// CHECK: dense_elements_i64: "0x400000000100000000000000020000000000000003000000000000007F00000000000000"
+
+// -----
+// CHECK-LABEL: func.func @dense_f16
+func.func @dense_f16() -> tensor<4xf16> {
+ // Non-splat f16 data is expected to be rewritten to a dense_resource whose
+ // payload holds the two-byte bit pattern of each value (663C 6640 9A42 0000).
+ // CHECK: dense_resource<dense_elements_f16>
+ %c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf16>
+ return %c123 : tensor<4xf16>
+}
+
+// CHECK: dense_elements_f16: "0x40000000663C66409A420000"
+
+// -----
+// CHECK-LABEL: func.func @dense_f32
+func.func @dense_f32() -> tensor<4xf32> {
+ // Non-splat f32 data is expected to be rewritten to a dense_resource whose
+ // payload holds the four-byte bit pattern of each value.
+ // CHECK: dense_resource<dense_elements_f32>
+ %c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf32>
+ return %c123 : tensor<4xf32>
+}
+
+// CHECK: dense_elements_f32: "0x40000000CDCC8C3FCDCC0C403333534000000000"
+
+// -----
+// CHECK-LABEL: func.func @dense_f64
+func.func @dense_f64() -> tensor<4xf64> {
+ // Non-splat f64 data is expected to be rewritten to a dense_resource whose
+ // payload holds the eight-byte bit pattern of each value.
+ // CHECK: dense_resource<dense_elements_f64>
+ %c123 = arith.constant dense<[1.1, 2.2, 3.3, 0.0]> : tensor<4xf64>
+ return %c123 : tensor<4xf64>
+}
+
+// CHECK: dense_elements_f64: "0x400000009A9999999999F13F9A999999999901406666666666660A400000000000000000"