blob: 6f9cfb6d70970d66325f98a4aaf31b706a034332 [file] [log] [blame] [edit]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <limits>
#include <type_traits>
#include <utility>
#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
#include "iree/compiler/Dialect/VM/IR/VMOps.h"
#include "iree/compiler/Dialect/VM/IR/VMTypes.h"
#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
namespace mlir::iree_compiler::IREE::VM {
#define GEN_PASS_DEF_REIFYRODATATABLESPASS
#include "iree/compiler/Dialect/VM/Transforms/Passes.h.inc"
// Replaces a vm.rodata.table.inline op with two vm.rodata.inline ops, one for
// the indexing table, and the second for the padded data.
template <typename IntTy>
static void reifyRodataTable(RewriterBase &rewriter,
IREE::VM::RodataTableInlineOp tableOp) {
SmallVector<IntTy> table;
SmallVector<Attribute> dataAttrs;
size_t dataSize = 0;
size_t dataAlignment =
tableOp.getDataAlignment() ? *tableOp.getDataAlignment() : 1;
for (auto value : tableOp.getDataArray().getValue()) {
auto serializableAttr = cast<IREE::Util::SerializableAttrInterface>(value);
size_t storageSize = serializableAttr.getStorageSize();
dataAttrs.push_back(value);
// Pad to the (byte) data alignment.
size_t padding =
(dataAlignment - storageSize % dataAlignment) % dataAlignment;
if (padding) {
SmallVector<int8_t> zeros(padding, 0);
VectorType paddingType = VectorType::get({static_cast<int64_t>(padding)},
rewriter.getIntegerType(8));
dataAttrs.push_back(rewriter.getZeroAttr(paddingType));
}
// The running data size is the offset of the current value.
table.push_back(dataSize);
// The table specifies the (unpadded) storage size for this element.
table.push_back(storageSize);
// Increment the total storage size by the (padded) storage size.
dataSize += storageSize + padding;
}
auto refType =
IREE::VM::RefType::get(rewriter.getType<IREE::VM::BufferType>());
IREE::VM::RodataInlineOp tableRodata;
if constexpr (std::is_same<IntTy, int32_t>()) {
tableRodata = IREE::VM::RodataInlineOp::create(
rewriter, tableOp.getLoc(), refType, rewriter.getI32VectorAttr(table));
} else {
tableRodata = IREE::VM::RodataInlineOp::create(
rewriter, tableOp.getLoc(), refType, rewriter.getI64VectorAttr(table));
}
if (auto tableNameAttr = tableOp.getTableNameAttr()) {
tableRodata.setNameAttr(tableNameAttr);
}
auto dataRodata = IREE::VM::RodataInlineOp::create(
rewriter, tableOp.getLoc(), refType,
IREE::Util::CompositeAttr::get(rewriter.getContext(), dataAttrs));
if (auto dataNameAttr = tableOp.getDataNameAttr()) {
dataRodata.setNameAttr(dataNameAttr);
}
if (auto alignmentAttr = tableOp.getAlignmentAttr()) {
tableRodata.setAlignmentAttr(alignmentAttr);
dataRodata.setAlignmentAttr(alignmentAttr);
}
if (auto mimeTypeAttr = tableOp.getMimeTypeAttr()) {
tableRodata.setMimeTypeAttr(mimeTypeAttr);
dataRodata.setMimeTypeAttr(mimeTypeAttr);
}
rewriter.replaceOp(tableOp, {tableRodata, dataRodata});
}
class ReifyRodataTablesPass
: public IREE::VM::impl::ReifyRodataTablesPassBase<ReifyRodataTablesPass> {
void runOnOperation() override {
IREE::VM::ModuleOp moduleOp = getOperation();
// Walk all of the rodata table ops and convert to rodata.inline
IRRewriter rewriter(moduleOp.getContext());
moduleOp.walk([&](IREE::VM::RodataTableInlineOp tableOp) {
rewriter.setInsertionPoint(tableOp);
size_t tableBitwidth = tableOp.getTableType().getIntOrFloatBitWidth();
if (tableBitwidth == 32) {
reifyRodataTable<int32_t>(rewriter, tableOp);
} else if (tableBitwidth == 64) {
reifyRodataTable<int64_t>(rewriter, tableOp);
} else {
llvm_unreachable("Invalid table bit width");
}
});
}
};
} // namespace mlir::iree_compiler::IREE::VM