Fixing vm.select.ref for moves into the same register. (#6137)
Previous behavior would lead to (safe) null access violations if the
non-selected value was moved into the same register:
%0 = vm.select.ref %cond, %t, %f
where: %0 and %t/%f alias a register and one or both %t/%f are moves.
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index 8765945..1019149 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -1230,12 +1230,16 @@
// Select LHS.
IREE_RETURN_IF_ERROR(iree_vm_ref_retain_or_move_checked(
true_value_is_move, true_value, type_def->ref_type, result));
- if (false_value_is_move) iree_vm_ref_release(false_value);
+ if (false_value_is_move && false_value != result) {
+ iree_vm_ref_release(false_value);
+ }
} else {
// Select RHS.
IREE_RETURN_IF_ERROR(iree_vm_ref_retain_or_move_checked(
false_value_is_move, false_value, type_def->ref_type, result));
- if (true_value_is_move) iree_vm_ref_release(true_value);
+ if (true_value_is_move && true_value != result) {
+ iree_vm_ref_release(true_value);
+ }
}
});
diff --git a/iree/vm/ref.h b/iree/vm/ref.h
index 191a6d4..5ee6343 100644
--- a/iree/vm/ref.h
+++ b/iree/vm/ref.h
@@ -170,8 +170,12 @@
// Checks that the given reference-counted pointer |ref| is of |type|.
static inline iree_status_t iree_vm_ref_check(const iree_vm_ref_t ref,
iree_vm_ref_type_t type) {
- return ref.type == type ? iree_ok_status()
- : iree_make_status(IREE_STATUS_INVALID_ARGUMENT);
+ return IREE_LIKELY(ref.type == type)
+ ? iree_ok_status()
+ : iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ ref.type == IREE_VM_REF_TYPE_NULL
+ ? "ref is null"
+ : "ref type mismatch");
}
// Retains the reference-counted pointer |ref|.
diff --git a/iree/vm/test/assignment_ops.mlir b/iree/vm/test/assignment_ops.mlir
index ef10708..967ceb8 100644
--- a/iree/vm/test/assignment_ops.mlir
+++ b/iree/vm/test/assignment_ops.mlir
@@ -16,4 +16,17 @@
vm.check.eq %v2, %c0, "1 ? 0 : 1 = 0" : i32
vm.return
}
+
+ vm.export @test_select_ref
+ vm.func @test_select_ref() {
+ %c0 = vm.const.i32 0 : i32
+ %list0 = vm.list.alloc %c0 : (i32) -> !vm.list<i8>
+ %c1 = vm.const.i32 1 : i32
+ %list1 = vm.list.alloc %c1 : (i32) -> !vm.list<i8>
+ %cond = vm.const.i32 0 : i32
+ %cond_dno = iree.do_not_optimize(%cond) : i32
+ %list = vm.select.ref %cond_dno, %list0, %list1 : !vm.list<i8>
+ vm.check.eq %list, %list1, "0 ? list0 : list1 = list1" : !vm.list<i8>
+ vm.return
+ }
}