Swapping `[]` -> `CD` and `.` -> `_` in the VM calling convention. (#4771)
This makes it easier to write simple 1:1 shim functions/wrappers without needing to special case characters or void/empty args/results.
Progress on #4736.
diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc
index d187069..eaa23c2 100644
--- a/bindings/python/pyiree/rt/function_abi.cc
+++ b/bindings/python/pyiree/rt/function_abi.cc
@@ -573,7 +573,11 @@
py::object this_object =
py::cast(this, py::return_value_policy::take_ownership);
if (descs.size() != f_results.size() || descs.size() != py_results.size()) {
- throw RaiseValueError("Mismatched RawUnpack() result arity");
+ std::string s = std::string("Mismatched RawUnpack() result arity; descs=") +
+ std::to_string(descs.size()) +
+ ", f_results=" + std::to_string(f_results.size()) +
+ ", py_results=" + std::to_string(py_results.size());
+ throw RaiseValueError(s.c_str());
}
for (size_t i = 0, e = descs.size(); i < e; ++i) {
const Description& desc = descs[i];
diff --git a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
index 316390d..ca6a8a9 100644
--- a/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
+++ b/iree/compiler/Dialect/VM/Target/C/CModuleTarget.cpp
@@ -78,16 +78,7 @@
if (!callingConvention) {
return funcOp.emitError("Couldn't create calling convention string");
}
- auto s = callingConvention.getValue();
- output << "call_";
- if (s.size() == 0) {
- output << "0_";
- } else {
- std::replace(s.begin(), s.end(), '.', '_');
- output << s;
- }
- output << "_shim";
-
+ output << "call_" << callingConvention.getValue() << "_shim";
return success();
}
diff --git a/iree/compiler/Dialect/VM/Target/CallingConventionUtils.cpp b/iree/compiler/Dialect/VM/Target/CallingConventionUtils.cpp
index 1bcbb8b..29578ec 100644
--- a/iree/compiler/Dialect/VM/Target/CallingConventionUtils.cpp
+++ b/iree/compiler/Dialect/VM/Target/CallingConventionUtils.cpp
@@ -66,9 +66,9 @@
LogicalResult encodeVariadicCallingConventionType(Operation *op, Type type,
SmallVectorImpl<char> &s) {
- s.push_back('[');
+ s.push_back('C');
auto result = encodeCallingConventionType(op, type, s);
- s.push_back(']');
+ s.push_back('D');
return result;
}
@@ -76,31 +76,37 @@
IREE::VM::ImportOp importOp) {
auto functionType = importOp.getType();
if (functionType.getNumInputs() == 0 && functionType.getNumResults() == 0) {
- return std::string{}; // Valid but empty.
+ return std::string("0v_v"); // Valid but empty.
}
SmallVector<char, 8> s = {'0'};
- for (int i = 0; i < functionType.getNumInputs(); ++i) {
- if (importOp.isFuncArgumentVariadic(i)) {
- if (failed(encodeVariadicCallingConventionType(
- importOp, functionType.getInput(i), s))) {
- return None;
- }
- } else {
- if (failed(encodeCallingConventionType(importOp, functionType.getInput(i),
- s))) {
- return None;
+ if (functionType.getNumInputs() > 0) {
+ for (int i = 0; i < functionType.getNumInputs(); ++i) {
+ if (importOp.isFuncArgumentVariadic(i)) {
+ if (failed(encodeVariadicCallingConventionType(
+ importOp, functionType.getInput(i), s))) {
+ return None;
+ }
+ } else {
+ if (failed(encodeCallingConventionType(importOp,
+ functionType.getInput(i), s))) {
+ return None;
+ }
}
}
+ } else {
+ s.push_back('v');
}
+ s.push_back('_');
if (functionType.getNumResults() > 0) {
- s.push_back('.');
for (int i = 0; i < functionType.getNumResults(); ++i) {
if (failed(encodeCallingConventionType(importOp,
functionType.getResult(i), s))) {
return None;
}
}
+ } else {
+ s.push_back('v');
}
return std::string(s.data(), s.size());
}
@@ -108,24 +114,30 @@
Optional<std::string> makeCallingConventionString(IREE::VM::FuncOp funcOp) {
auto functionType = funcOp.getType();
if (functionType.getNumInputs() == 0 && functionType.getNumResults() == 0) {
- return std::string{}; // Valid but empty.
+ return std::string("0v_v"); // Valid but empty.
}
SmallVector<char, 8> s = {'0'};
- for (int i = 0; i < functionType.getNumInputs(); ++i) {
- if (failed(
- encodeCallingConventionType(funcOp, functionType.getInput(i), s))) {
- return None;
+ if (functionType.getNumInputs() > 0) {
+ for (int i = 0; i < functionType.getNumInputs(); ++i) {
+ if (failed(encodeCallingConventionType(funcOp, functionType.getInput(i),
+ s))) {
+ return None;
+ }
}
+ } else {
+ s.push_back('v');
}
+ s.push_back('_');
if (functionType.getNumResults() > 0) {
- s.push_back('.');
for (int i = 0; i < functionType.getNumResults(); ++i) {
if (failed(encodeCallingConventionType(funcOp, functionType.getResult(i),
s))) {
return None;
}
}
+ } else {
+ s.push_back('v');
}
return std::string(s.data(), s.size());
}
diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c
index 0a77220..32594c2 100644
--- a/iree/vm/bytecode_dispatch.c
+++ b/iree/vm/bytecode_dispatch.c
@@ -200,6 +200,8 @@
const uint8_t* p = arguments.data;
for (iree_host_size_t i = 0; i < cconv_arguments.size; ++i) {
switch (cconv_arguments.data[i]) {
+ case IREE_VM_CCONV_TYPE_VOID:
+ break;
case IREE_VM_CCONV_TYPE_INT32: {
uint16_t dst_reg = i32_reg++;
memcpy(&callee_registers.i32[dst_reg & callee_registers.i32_mask], p,
@@ -241,6 +243,8 @@
for (iree_host_size_t i = 0; i < cconv_results.size; ++i) {
uint16_t src_reg = src_reg_list->registers[i];
switch (cconv_results.data[i]) {
+ case IREE_VM_CCONV_TYPE_VOID:
+ break;
case IREE_VM_CCONV_TYPE_INT32: {
memcpy(p, &callee_registers->i32[src_reg & callee_registers->i32_mask],
sizeof(int32_t));
@@ -380,6 +384,8 @@
for (iree_host_size_t i = 0, seg_i = 0, reg_i = 0; i < cconv_arguments.size;
++i, ++seg_i) {
switch (cconv_arguments.data[i]) {
+ case IREE_VM_CCONV_TYPE_VOID:
+ break;
case IREE_VM_CCONV_TYPE_INT32: {
memcpy(p,
&caller_registers.i32[src_reg_list->registers[reg_i++] &
@@ -423,6 +429,8 @@
++i) {
// TODO(benvanik): share with switch above.
switch (cconv_arguments.data[i]) {
+ case IREE_VM_CCONV_TYPE_VOID:
+ break;
case IREE_VM_CCONV_TYPE_INT32: {
memcpy(p,
&caller_registers.i32[src_reg_list->registers[reg_i++] &
@@ -484,6 +492,8 @@
++i) {
uint16_t dst_reg = dst_reg_list->registers[i];
switch (cconv_results.data[i]) {
+ case IREE_VM_CCONV_TYPE_VOID:
+ break;
case IREE_VM_CCONV_TYPE_INT32:
memcpy(&caller_registers.i32[dst_reg & caller_registers.i32_mask], p,
sizeof(int32_t));
diff --git a/iree/vm/bytecode_module_benchmark.cc b/iree/vm/bytecode_module_benchmark.cc
index 529ff3a..ed54bf4 100644
--- a/iree/vm/bytecode_module_benchmark.cc
+++ b/iree/vm/bytecode_module_benchmark.cc
@@ -44,7 +44,7 @@
static const iree_vm_native_export_descriptor_t
native_import_module_exports_[] = {
- {iree_make_cstring_view("add_1"), iree_make_cstring_view("0i.i"), 0,
+ {iree_make_cstring_view("add_1"), iree_make_cstring_view("0i_i"), 0,
NULL},
};
static const iree_vm_native_function_ptr_t native_import_module_funcs_[] = {
diff --git a/iree/vm/invocation.c b/iree/vm/invocation.c
index f223cd9..b8f3abb 100644
--- a/iree/vm/invocation.c
+++ b/iree/vm/invocation.c
@@ -23,24 +23,30 @@
iree_byte_span_t arguments) {
// We are 1:1 right now with no variadic args, so do a quick verification on
// the input list.
- if (!inputs) {
- if (cconv_arguments.size > 0) {
+ iree_host_size_t expected_input_count =
+ cconv_arguments.size > 0
+ ? (cconv_arguments.data[0] == 'v' ? 0 : cconv_arguments.size)
+ : 0;
+ if (IREE_UNLIKELY(!inputs)) {
+ if (IREE_UNLIKELY(expected_input_count > 0)) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"no input provided to a function that has inputs");
}
return iree_ok_status();
- } else if (cconv_arguments.size != iree_vm_list_size(inputs)) {
+ } else if (IREE_UNLIKELY(expected_input_count != iree_vm_list_size(inputs))) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"input list and function mismatch; expected %zu "
"arguments but passed %zu",
- cconv_arguments.size, iree_vm_list_size(inputs));
+ expected_input_count, iree_vm_list_size(inputs));
}
uint8_t* p = arguments.data;
for (iree_host_size_t cconv_i = 0, arg_i = 0; cconv_i < cconv_arguments.size;
++cconv_i, ++arg_i) {
switch (cconv_arguments.data[cconv_i]) {
+ case IREE_VM_CCONV_TYPE_VOID:
+ break;
case IREE_VM_CCONV_TYPE_INT32: {
iree_vm_value_t value;
IREE_RETURN_IF_ERROR(iree_vm_list_get_value_as(
@@ -71,8 +77,12 @@
static iree_status_t iree_vm_invoke_marshal_outputs(
iree_string_view_t cconv_results, iree_byte_span_t results,
iree_vm_list_t* outputs) {
- if (!outputs) {
- if (cconv_results.size > 0) {
+ iree_host_size_t expected_output_count =
+ cconv_results.size > 0
+ ? (cconv_results.data[0] == 'v' ? 0 : cconv_results.size)
+ : 0;
+ if (IREE_UNLIKELY(!outputs)) {
+ if (IREE_UNLIKELY(expected_output_count > 0)) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"no output provided to a function that has outputs");
@@ -83,12 +93,14 @@
// Resize the output list to hold all results (and kill anything that may
// have been in there).
IREE_RETURN_IF_ERROR(iree_vm_list_resize(outputs, 0));
- IREE_RETURN_IF_ERROR(iree_vm_list_resize(outputs, cconv_results.size));
+ IREE_RETURN_IF_ERROR(iree_vm_list_resize(outputs, expected_output_count));
uint8_t* p = results.data;
for (iree_host_size_t cconv_i = 0, arg_i = 0; cconv_i < cconv_results.size;
++cconv_i, ++arg_i) {
switch (cconv_results.data[cconv_i]) {
+ case IREE_VM_CCONV_TYPE_VOID:
+ break;
case IREE_VM_CCONV_TYPE_INT32: {
iree_vm_value_t value = iree_vm_value_make_i32(*(int32_t*)p);
IREE_RETURN_IF_ERROR(iree_vm_list_set_value(outputs, arg_i, &value));
diff --git a/iree/vm/module.c b/iree/vm/module.c
index a0a92f5..07aa8e1 100644
--- a/iree/vm/module.c
+++ b/iree/vm/module.c
@@ -35,7 +35,7 @@
"unsupported cconv version %c", cconv.data[0]);
}
iree_string_view_t cconv_body = iree_string_view_substr(cconv, 1, INTPTR_MAX);
- if (iree_string_view_split(cconv_body, '.', out_arguments, out_results) ==
+ if (iree_string_view_split(cconv_body, '_', out_arguments, out_results) ==
-1) {
*out_arguments = cconv_body;
}
@@ -57,6 +57,8 @@
for (iree_host_size_t i = 0, seg_i = 0; i < cconv_fragment.size;
++i, ++seg_i) {
switch (cconv_fragment.data[i]) {
+ case IREE_VM_CCONV_TYPE_VOID:
+ break;
case IREE_VM_CCONV_TYPE_INT32:
required_size += sizeof(int32_t);
break;
@@ -80,6 +82,8 @@
cconv_fragment.data[i] != IREE_VM_CCONV_TYPE_SPAN_END;
++i) {
switch (cconv_fragment.data[i]) {
+ case IREE_VM_CCONV_TYPE_VOID:
+ break;
case IREE_VM_CCONV_TYPE_INT32:
span_size += sizeof(int32_t);
break;
@@ -118,11 +122,13 @@
uint8_t* p = call->arguments.data;
for (iree_host_size_t i = 1; i < cconv.size; ++i) {
char c = cconv.data[i];
- if (c == '.') {
+ if (c == '_') {
// Switch to results.
p = call->results.data;
}
switch (c) {
+ case IREE_VM_CCONV_TYPE_VOID:
+ break;
case IREE_VM_CCONV_TYPE_INT32:
p += sizeof(int32_t);
break;
diff --git a/iree/vm/module.h b/iree/vm/module.h
index e62fd61..d8635e0 100644
--- a/iree/vm/module.h
+++ b/iree/vm/module.h
@@ -109,17 +109,17 @@
// - 'i': int32_t integer (i32)
// - 'I': int64_t integer (i64)
// - 'r': ref-counted type pointer (!vm.ref<?>)
- // - '[' ... ']': variadic list of flattened tuples of a specified type
- // - EOL or '.'
+ // - 'C' ... 'D': variadic list of flattened tuples of a specified type
+ // - EOL or '_'
// - Zero or more results:
// - 'i' or 'I'
// - 'r'
//
// Examples:
- // `0` or `0.`: () -> ()
- // `0i` or `0i.`: (i32) -> ()
- // `0ii[ii].i`: (i32, i32, tuple<i32, i32>...) -> i32
- // `0ir[ir].r`: (i32, !vm.ref<?>, tuple<i32, !vm.ref<?>>) -> !vm.ref<?>
+ // `0` or `0_`: () -> ()
+ // `0i` or `0i_`: (i32) -> ()
+ // `0iiCiiD_i`: (i32, i32, tuple<i32, i32>...) -> i32
+ // `0irCirD_r`: (i32, !vm.ref<?>, tuple<i32, !vm.ref<?>>) -> !vm.ref<?>
//
// Users of this field must verify the version prefix in the first byte before
// using the declaration.
@@ -207,11 +207,12 @@
iree_byte_span_t results;
} iree_vm_function_call_t;
+#define IREE_VM_CCONV_TYPE_VOID 'v'
#define IREE_VM_CCONV_TYPE_INT32 'i'
#define IREE_VM_CCONV_TYPE_INT64 'I'
#define IREE_VM_CCONV_TYPE_REF 'r'
-#define IREE_VM_CCONV_TYPE_SPAN_START '['
-#define IREE_VM_CCONV_TYPE_SPAN_END ']'
+#define IREE_VM_CCONV_TYPE_SPAN_START 'C'
+#define IREE_VM_CCONV_TYPE_SPAN_END 'D'
// Returns the arguments and results fragments from the function signature.
// Either may be empty if they have no values.
@@ -219,9 +220,11 @@
// Example:
// `` -> arguments = ``, results = ``
// `0` -> arguments = ``, results = ``
+// `0v` -> arguments = ``, results = ``
// `0ri` -> arguments = `ri`, results = ``
-// `0.ir` -> arguments = ``, results = `ir`
-// `0i[i].rr` -> arguments = `i[i]`, results = `rr`
+// `0_ir` -> arguments = ``, results = `ir`
+// `0v_ir` -> arguments = ``, results = `ir`
+// `0iCiD_rr` -> arguments = `iCiD`, results = `rr`
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_vm_function_call_get_cconv_fragments(
const iree_vm_function_signature_t* signature,
@@ -232,7 +235,7 @@
iree_vm_function_call_is_variadic_cconv(iree_string_view_t cconv);
// Returns the required size, in bytes, to store the data in the given cconv
-// fragment (like `iI[ri]r`).
+// fragment (like `iICriDr`).
//
// The provided |segment_size_list| is used for variadic arguments/results. Each
// entry represents one of the top level arguments with spans being flattened.
diff --git a/iree/vm/module_abi_packing.h b/iree/vm/module_abi_packing.h
index fdfa8dd..0b00b30 100644
--- a/iree/vm/module_abi_packing.h
+++ b/iree/vm/module_abi_packing.h
@@ -219,11 +219,11 @@
template <typename U>
struct cconv_map<absl::Span<U>> {
static constexpr const auto conv_chars = concat_literals(
- literal("["), cconv_map<typename impl::remove_cvref<U>::type>::conv_chars,
- literal("]"));
+ literal("C"), cconv_map<typename impl::remove_cvref<U>::type>::conv_chars,
+ literal("D"));
};
-template <typename Result, typename... Params>
+template <typename Result, size_t ParamsCount, typename... Params>
struct cconv_storage {
static const iree_string_view_t value() {
static constexpr const auto value = concat_literals(
@@ -231,7 +231,7 @@
concat_literals(
cconv_map<
typename impl::remove_cvref<Params>::type>::conv_chars...),
- literal("."),
+ literal("_"),
concat_literals(
cconv_map<typename impl::remove_cvref<Result>::type>::conv_chars));
static constexpr const auto str =
@@ -240,14 +240,38 @@
}
};
-template <typename... Params>
+template <typename Result>
+struct cconv_storage<Result, 0> {
+ static const iree_string_view_t value() {
+ static constexpr const auto value = concat_literals(
+ literal("0v_"),
+ concat_literals(
+ cconv_map<typename impl::remove_cvref<Result>::type>::conv_chars));
+ static constexpr const auto str =
+ iree_string_view_t{value.data(), value.size()};
+ return str;
+ }
+};
+
+template <size_t ParamsCount, typename... Params>
struct cconv_storage_void {
static const iree_string_view_t value() {
static constexpr const auto value = concat_literals(
literal("0"),
concat_literals(
cconv_map<
- typename impl::remove_cvref<Params>::type>::conv_chars...));
+ typename impl::remove_cvref<Params>::type>::conv_chars...),
+ literal("_v"));
+ static constexpr const auto str =
+ iree_string_view_t{value.data(), value.size()};
+ return str;
+ }
+};
+
+template <>
+struct cconv_storage_void<0> {
+ static const iree_string_view_t value() {
+ static constexpr const auto value = concat_literals(literal("0v_v"));
static constexpr const auto str =
iree_string_view_t{value.data(), value.size()};
return str;
@@ -628,7 +652,7 @@
absl::string_view name, StatusOr<Result> (Owner::*fn)(Params...)) {
using dispatch_functor_t = packing::DispatchFunctor<Owner, Result, Params...>;
return {{name.data(), name.size()},
- packing::cconv_storage<Result, Params...>::value(),
+ packing::cconv_storage<Result, sizeof...(Params), Params...>::value(),
(void (Owner::*)())fn,
&dispatch_functor_t::Call};
}
@@ -638,7 +662,7 @@
absl::string_view name, Status (Owner::*fn)(Params...)) {
using dispatch_functor_t = packing::DispatchFunctorVoid<Owner, Params...>;
return {{name.data(), name.size()},
- packing::cconv_storage_void<Params...>::value(),
+ packing::cconv_storage_void<sizeof...(Params), Params...>::value(),
(void (Owner::*)())fn,
&dispatch_functor_t::Call};
}
diff --git a/iree/vm/native_module_test.h b/iree/vm/native_module_test.h
index 35de2d5..e37c2c4 100644
--- a/iree/vm/native_module_test.h
+++ b/iree/vm/native_module_test.h
@@ -101,8 +101,8 @@
}
static const iree_vm_native_export_descriptor_t module_a_exports_[] = {
- {iree_make_cstring_view("add_1"), iree_make_cstring_view("0i.i"), 0, NULL},
- {iree_make_cstring_view("sub_1"), iree_make_cstring_view("0i.i"), 0, NULL},
+ {iree_make_cstring_view("add_1"), iree_make_cstring_view("0i_i"), 0, NULL},
+ {iree_make_cstring_view("sub_1"), iree_make_cstring_view("0i_i"), 0, NULL},
};
static const iree_vm_native_function_ptr_t module_a_funcs_[] = {
{(iree_vm_native_function_shim_t)call_shim_i32_i32,
@@ -257,7 +257,7 @@
{iree_make_cstring_view("key1"), iree_make_cstring_view("value1")},
};
static const iree_vm_native_export_descriptor_t module_b_exports_[] = {
- {iree_make_cstring_view("entry"), iree_make_cstring_view("0i.i"),
+ {iree_make_cstring_view("entry"), iree_make_cstring_view("0i_i"),
IREE_ARRAYSIZE(module_b_entry_attrs_), module_b_entry_attrs_},
};
static_assert(IREE_ARRAYSIZE(module_b_funcs_) ==
diff --git a/iree/vm/shims.h b/iree/vm/shims.h
index b4d06ec..1a49c77 100644
--- a/iree/vm/shims.h
+++ b/iree/vm/shims.h
@@ -19,22 +19,21 @@
#include "iree/vm/stack.h"
// see Calling convetion in module.h
-// We use the same format but replace '.' by '_'
// Variadic arguments are not supported
-// 0.
-typedef iree_status_t (*call_0__t)(iree_vm_stack_t* stack, void* module_ptr,
- void* module_state);
+// 0v_v
+typedef iree_status_t (*call_0v_v_t)(iree_vm_stack_t* stack, void* module_ptr,
+ void* module_state);
-static iree_status_t call_0__shim(iree_vm_stack_t* stack,
- const iree_vm_function_call_t* call,
- call_0__t target_fn, void* module,
- void* module_state,
- iree_vm_execution_result_t* out_result) {
+static iree_status_t call_0v_v_shim(iree_vm_stack_t* stack,
+ const iree_vm_function_call_t* call,
+ call_0v_v_t target_fn, void* module,
+ void* module_state,
+ iree_vm_execution_result_t* out_result) {
return target_fn(stack, module, module_state);
}
-// 0i.i
+// 0i_i
typedef iree_status_t (*call_0i_i_t)(iree_vm_stack_t* stack, void* module_ptr,
void* module_state, int32_t arg0,
int32_t* res0);
@@ -56,7 +55,8 @@
return target_fn(stack, module, module_state, args->arg0, &results->ret0);
}
-// 0ii.i
+
+// 0ii_i
typedef iree_status_t (*call_0ii_i_t)(iree_vm_stack_t* stack, void* module_ptr,
void* module_state, int32_t arg0,
int32_t arg1, int32_t* res0);