Forcing 2-byte alignment on VM register lists. (#6579)
This is a hack for #6566 but not a bad idea in general. Future revisions
of the bytecode format will take this into account natively such that
we don't need any unaligned loads at runtime.
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
index 7e96267..dc85669 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp
@@ -148,7 +148,7 @@
}
LogicalResult encodePrimitiveArrayAttr(DenseElementsAttr value) override {
- if (value.getNumElements() > UINT16_MAX ||
+ if (value.getNumElements() > UINT16_MAX || failed(ensureAlignment(2)) ||
failed(writeUint16(value.getNumElements()))) {
return currentOp_->emitOpError() << "integer array size out of bounds";
}
@@ -187,7 +187,9 @@
// this list is small :)
auto srcDstRegs = registerAllocation_->remapSuccessorRegisters(
currentOp_, successorIndex);
- (void)writeUint16(srcDstRegs.size());
+ if (failed(ensureAlignment(2)) || failed(writeUint16(srcDstRegs.size()))) {
+ return failure();
+ }
for (auto srcDstReg : srcDstRegs) {
if (failed(writeUint16(srcDstReg.first.encode())) ||
failed(writeUint16(srcDstReg.second.encode()))) {
@@ -206,7 +208,10 @@
}
LogicalResult encodeOperands(Operation::operand_range values) override {
- (void)writeUint16(std::distance(values.begin(), values.end()));
+ if (failed(ensureAlignment(2)) ||
+ failed(writeUint16(std::distance(values.begin(), values.end())))) {
+ return failure();
+ }
for (auto it : llvm::enumerate(values)) {
uint16_t reg = registerAllocation_
->mapUseToRegister(it.value(), currentOp_, it.index())
@@ -224,7 +229,10 @@
}
LogicalResult encodeResults(Operation::result_range values) override {
- (void)writeUint16(std::distance(values.begin(), values.end()));
+ if (failed(ensureAlignment(2)) ||
+ failed(writeUint16(std::distance(values.begin(), values.end())))) {
+ return failure();
+ }
for (auto value : values) {
uint16_t reg = registerAllocation_->mapToRegister(value).encode();
if (failed(writeUint16(reg))) {
@@ -241,6 +249,15 @@
return std::move(bytecode_);
}
+ LogicalResult ensureAlignment(size_t alignment) {  // Pads with zero bytes (via writeBytes) until bytecode_.size() is a multiple of `alignment`.
+ size_t paddedSize = (bytecode_.size() + (alignment - 1)) & ~(alignment - 1);  // Round-up mask trick; only correct when `alignment` is a power of two.
+ size_t padding = paddedSize - bytecode_.size();  // Number of filler bytes still needed.
+ if (padding == 0) return success();  // Already aligned; nothing is written.
+ static const uint8_t kZeros[32] = {0};  // Source of the zero filler; caps a single pad at 32 bytes.
+ if (padding > sizeof(kZeros)) return failure();  // Fail instead of reading past the end of kZeros.
+ return writeBytes(kZeros, padding);  // Result of the write is returned to the caller unchanged.
+ }
+
private:
// TODO(benvanik): replace this with something not using an ever-expanding
// vector. I'm sure LLVM has something.
@@ -348,6 +365,10 @@
}
}
+ if (failed(encoder.ensureAlignment(8))) {
+ funcOp.emitError() << "failed to pad function";
+ return llvm::None;
+ }
auto bytecodeData = encoder.finish();
if (!bytecodeData.hasValue()) {
funcOp.emitError() << "failed to fixup and finish encoding";
diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir b/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir
index 6c5e100..ca97e3c 100644
--- a/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir
+++ b/iree/compiler/Dialect/VM/Target/Bytecode/test/module_encoding_smoke.mlir
@@ -18,15 +18,18 @@
// CHECK: "function_descriptors":
// CHECK-NEXT: {
// CHECK-NEXT: "bytecode_offset": 0
- // CHECK-NEXT: "bytecode_length": 5
+ // CHECK-NEXT: "bytecode_length": 8
// CHECK-NEXT: "i32_register_count": 1
// CHECK-NEXT: "ref_register_count": 0
// CHECK-NEXT: }
// CHECK: "bytecode_data": [
// CHECK-NEXT: 84,
+ // CHECK-NEXT: 0,
// CHECK-NEXT: 1,
// CHECK-NEXT: 0,
// CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
+ // CHECK-NEXT: 0,
// CHECK-NEXT: 0
// CHECK-NEXT: ]
}
diff --git a/iree/compiler/Translation/test/smoketest.mlir b/iree/compiler/Translation/test/smoketest.mlir
index 1f3b7ca..26306f2 100644
--- a/iree/compiler/Translation/test/smoketest.mlir
+++ b/iree/compiler/Translation/test/smoketest.mlir
@@ -14,12 +14,13 @@
// CHECK: "function_descriptors":
// CHECK-NEXT: {
// CHECK-NEXT: "bytecode_offset": 0
-// CHECK-NEXT: "bytecode_length": 5
+// CHECK-NEXT: "bytecode_length": 8
// CHECK-NEXT: "i32_register_count": 1
// CHECK-NEXT: "ref_register_count": 0
// CHECK-NEXT: }
// CHECK: "bytecode_data": [
// CHECK-NEXT: 84,
+// CHECK-NEXT: 0,
// CHECK-NEXT: 1,
// CHECK-NEXT: 0,
// CHECK-NEXT: 0,
diff --git a/iree/vm/bytecode_dispatch_util.h b/iree/vm/bytecode_dispatch_util.h
index 596968e..81bb1e6 100644
--- a/iree/vm/bytecode_dispatch_util.h
+++ b/iree/vm/bytecode_dispatch_util.h
@@ -169,6 +169,9 @@
// Each macro will increment the pc by the number of bytes read and as such must
// be called in the same order the values are encoded.
+#define VM_AlignPC(pc, alignment) \
+ (pc) = ((pc) + ((alignment)-1)) & ~((alignment)-1)  // Rounds `pc` up in place to a multiple of `alignment` (power of two required); both arguments are evaluated more than once.
+
#define VM_DecConstI8(name) \
OP_I8(0); \
++pc;  // Advances pc by one byte after the OP_I8(0) read; `name` is unused. OP_I8 is defined outside this hunk.
@@ -201,11 +204,16 @@
(out_str)->data = (const char*)&bytecode_data[pc + 2]; \
pc += 2 + (out_str)->size;
#define VM_DecBranchTarget(block_name) VM_DecConstI32(name)  // NOTE(review): `block_name` is unused and the literal token `name` is forwarded; harmless only if VM_DecConstI32 ignores its argument - confirm.
-#define VM_DecBranchOperands(operands_name) \
- (const iree_vm_register_remap_list_t*)&bytecode_data[pc]; \
- pc += \
- kRegSize + ((const iree_vm_register_list_t*)&bytecode_data[pc])->size * \
- 2 * kRegSize;
+#define VM_DecBranchOperands(operands_name) \
+ VM_DecBranchOperandsImpl(bytecode_data, &pc)  // `operands_name` is unused; relies on `bytecode_data` and `pc` being in scope at the expansion site.
+static inline const iree_vm_register_remap_list_t* VM_DecBranchOperandsImpl(
+ const uint8_t* IREE_RESTRICT bytecode_data, iree_vm_source_offset_t* pc) {  // `pc` is in/out: read as the current offset, written with the offset just past the list.
+ VM_AlignPC(*pc, kRegSize);  // Skip any padding so the list starts on a kRegSize boundary.
+ const iree_vm_register_remap_list_t* list =
+ (const iree_vm_register_remap_list_t*)&bytecode_data[*pc];  // Points directly into bytecode_data; nothing is copied.
+ *pc = *pc + kRegSize + list->size * 2 * kRegSize;  // Advance past the size field plus `size` entries of two registers (src, dst) each.
+ return list;
+}
#define VM_DecOperandRegI32(name) \
regs.i32[OP_I16(0) & regs.i32_mask]; \
pc += kRegSize;
@@ -222,10 +230,16 @@
&regs.ref[OP_I16(0) & regs.ref_mask]; \
*(out_is_move) = OP_I16(0) & IREE_REF_REGISTER_MOVE_BIT; \
pc += kRegSize;
-#define VM_DecVariadicOperands(name) \
- (const iree_vm_register_list_t*)&bytecode_data[pc]; \
- pc += kRegSize + \
- ((const iree_vm_register_list_t*)&bytecode_data[pc])->size * kRegSize;
+#define VM_DecVariadicOperands(name) \
+ VM_DecVariadicOperandsImpl(bytecode_data, &pc)  // `name` is unused; relies on `bytecode_data` and `pc` being in scope at the expansion site.
+static inline const iree_vm_register_list_t* VM_DecVariadicOperandsImpl(
+ const uint8_t* IREE_RESTRICT bytecode_data, iree_vm_source_offset_t* pc) {  // `pc` is in/out: read as the current offset, written with the offset just past the list.
+ VM_AlignPC(*pc, kRegSize);  // Skip any padding so the list starts on a kRegSize boundary.
+ const iree_vm_register_list_t* list =
+ (const iree_vm_register_list_t*)&bytecode_data[*pc];  // Points directly into bytecode_data; nothing is copied.
+ *pc = *pc + kRegSize + list->size * kRegSize;  // Advance past the size field plus `size` register entries.
+ return list;
+}
#define VM_DecResultRegI32(name) \
&regs.i32[OP_I16(0) & regs.i32_mask]; \
pc += kRegSize;