Refactor check_output_data and add MlOutput. This CL moves the shared code in check_output_data into util.c and creates a struct that represents the output of an execution. The values of these structs will be set to CSRs after program completion in a follow up CL. As for what MlOutput points to: each model will need a shared struct representing the format of the output, to be used in Rust applications as well. An example is added to person_detection. Change-Id: I7a36cd2fa85f2471671dd15347f5a4c5c1c849fd
diff --git a/samples/float_model/mnist.c b/samples/float_model/mnist.c index 975dbea..98bfbbb 100644 --- a/samples/float_model/mnist.c +++ b/samples/float_model/mnist.c
@@ -7,6 +7,7 @@ #include "iree/base/api.h" #include "iree/hal/api.h" #include "samples/util/util.h" +#include "mnist.h" // Compiled module embedded here to avoid file IO: #include "samples/float_model/mnist_bytecode_module_dylib_c.h" @@ -26,6 +27,8 @@ .model_name = "mnist", }; +MnistOutput score; + const iree_const_byte_span_t load_bytecode_module_data() { const struct iree_file_toc_t *module_file_toc = samples_float_model_mnist_bytecode_module_dylib_create(); @@ -40,27 +43,27 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); // find the label index with best prediction float best_out = 0.0; int best_idx = -1; - for (int i = 0; i < model->output_length[index_output]; ++i) { - float out = ((float *)mapped_memory->contents.data)[i]; + for (int i = 0; i < model->output_length[0]; ++i) { + float out = ((float *)buffers[0].contents.data)[i]; if (out > best_out) { best_out = out; best_idx = i; } } + + score.best_out = best_out; + score.best_idx = best_idx; + LOG_INFO("Digit recognition result is: digit: %d", best_idx); + + output->result = &score; + output->len = sizeof(score); return result; }
diff --git a/samples/float_model/mnist.h b/samples/float_model/mnist.h new file mode 100644 index 0000000..67b2307 --- /dev/null +++ b/samples/float_model/mnist.h
@@ -0,0 +1,11 @@ +#ifndef SAMPLES_MNIST_H +#define SAMPLES_MNIST_H + +#include <stdint.h> + +typedef struct { + int best_idx; + float best_out; +} MnistOutput; + +#endif
diff --git a/samples/float_model/mobilenet_v1.c b/samples/float_model/mobilenet_v1.c index d354dfc..a0163f5 100644 --- a/samples/float_model/mobilenet_v1.c +++ b/samples/float_model/mobilenet_v1.c
@@ -7,6 +7,7 @@ #include "iree/base/api.h" #include "iree/hal/api.h" #include "samples/util/util.h" +#include "mobilenet_v1.h" // Compiled module embedded here to avoid file IO: #include "samples/float_model/mobilenet_input_c.h" @@ -26,6 +27,8 @@ .model_name = "mobilenet_v1_0.25_224_float", }; +MobilenetV1Output score; + const iree_const_byte_span_t load_bytecode_module_data() { const struct iree_file_toc_t *module_file_toc = samples_float_model_mobilenet_v1_bytecode_module_dylib_create(); @@ -40,27 +43,26 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); // find the label index with best prediction float best_out = 0.0; int best_idx = -1; - for (int i = 0; i < model->output_length[index_output]; ++i) { - float out = ((float *)mapped_memory->contents.data)[i]; + for (int i = 0; i < model->output_length[0]; ++i) { + float out = ((float *)buffers[0].contents.data)[i]; if (out > best_out) { best_out = out; best_idx = i; } } + score.best_out = best_out; + score.best_idx = best_idx; + LOG_INFO("Image prediction result is: id: %d", best_idx + 1); + + output->result = &score; + output->len = sizeof(score); return result; }
diff --git a/samples/float_model/mobilenet_v1.h b/samples/float_model/mobilenet_v1.h new file mode 100644 index 0000000..1017ba6 --- /dev/null +++ b/samples/float_model/mobilenet_v1.h
@@ -0,0 +1,11 @@ +#ifndef SAMPLES_MOBILENETV1_H +#define SAMPLES_MOBILENETV1_H + +#include <stdint.h> + +typedef struct { + int best_idx; + float best_out; +} MobilenetV1Output; + +#endif
diff --git a/samples/quant_model/barcode.c b/samples/quant_model/barcode.c index d29ce8a..105c6da 100644 --- a/samples/quant_model/barcode.c +++ b/samples/quant_model/barcode.c
@@ -46,16 +46,8 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { - iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); - return result; +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { + return iree_ok_status(); }
diff --git a/samples/quant_model/daredevil.c b/samples/quant_model/daredevil.c index 5b8754f..c20447c 100644 --- a/samples/quant_model/daredevil.c +++ b/samples/quant_model/daredevil.c
@@ -44,16 +44,8 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { - iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); - return result; +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { + return iree_ok_status(); }
diff --git a/samples/quant_model/fssd_25_8bit_v2.c b/samples/quant_model/fssd_25_8bit_v2.c index ab71d4b..be979db 100644 --- a/samples/quant_model/fssd_25_8bit_v2.c +++ b/samples/quant_model/fssd_25_8bit_v2.c
@@ -45,16 +45,8 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { - iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); - return result; +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { + return iree_ok_status(); }
diff --git a/samples/quant_model/mobilenet_v1.c b/samples/quant_model/mobilenet_v1.c index 3ece6b3..f55232d 100644 --- a/samples/quant_model/mobilenet_v1.c +++ b/samples/quant_model/mobilenet_v1.c
@@ -7,6 +7,7 @@ #include "iree/base/api.h" #include "iree/hal/api.h" #include "samples/util/util.h" +#include "mobilenet_v1.h" // Compiled module embedded here to avoid file IO: #include "samples/quant_model/mobilenet_quant_input_c.h" @@ -26,6 +27,8 @@ .model_name = "mobilenet_v1_0.25_224_quant", }; +MobilenetV1Output score; + const iree_const_byte_span_t load_bytecode_module_data() { const struct iree_file_toc_t *module_file_toc = samples_quant_model_mobilenet_v1_bytecode_module_dylib_create(); @@ -41,27 +44,26 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); // find the label index with best prediction int best_out = 0; int best_idx = -1; - for (int i = 0; i < model->output_length[index_output]; ++i) { - uint8_t out = ((uint8_t *)mapped_memory->contents.data)[i]; + for (int i = 0; i < model->output_length[0]; ++i) { + uint8_t out = ((uint8_t *)buffers[0].contents.data)[i]; if (out > best_out) { best_out = out; best_idx = i; } } + score.best_out = best_out; + score.best_idx = best_idx; + LOG_INFO("Image prediction result is: id: %d", best_idx + 1); + + output->result = &score; + output->len = sizeof(score); return result; }
diff --git a/samples/quant_model/mobilenet_v1.h b/samples/quant_model/mobilenet_v1.h new file mode 100644 index 0000000..5277547 --- /dev/null +++ b/samples/quant_model/mobilenet_v1.h
@@ -0,0 +1,11 @@ +#ifndef SAMPLES_MOBILENETV1_H +#define SAMPLES_MOBILENETV1_H + +#include <stdint.h> + +typedef struct { + int best_idx; + int best_out; +} MobilenetV1Output; + +#endif
diff --git a/samples/quant_model/mobilenet_v2.c b/samples/quant_model/mobilenet_v2.c index 7451a58..612ce65 100644 --- a/samples/quant_model/mobilenet_v2.c +++ b/samples/quant_model/mobilenet_v2.c
@@ -44,16 +44,8 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { - iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); - return result; +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { + return iree_ok_status(); }
diff --git a/samples/quant_model/person_detection.c b/samples/quant_model/person_detection.c index 22575dc..0f8e4ea 100644 --- a/samples/quant_model/person_detection.c +++ b/samples/quant_model/person_detection.c
@@ -7,6 +7,7 @@ #include "iree/base/api.h" #include "iree/hal/api.h" #include "samples/util/util.h" +#include "person_detection.h" // Compiled module embedded here to avoid file IO: #include "samples/quant_model/person_detection_bytecode_module_dylib_c.h" @@ -26,6 +27,8 @@ .model_name = "person_detection_quant", }; +PersonDetectionOutput detection; + const iree_const_byte_span_t load_bytecode_module_data() { const struct iree_file_toc_t *module_file_toc = samples_quant_model_person_detection_bytecode_module_dylib_create(); @@ -41,18 +44,19 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); - int8_t *data = (int8_t *)mapped_memory->contents.data; - LOG_INFO("Output: Non-person Score: %d; Person Score: %d", data[0], data[1]); + int8_t *data = (int8_t *)buffers[0].contents.data; + detection.non_person_score = data[0]; + detection.person_score = data[1]; + + LOG_INFO("Output: Non-person Score: %d; Person Score: %d", + detection.non_person_score, + detection.person_score); + output->result = &detection; + output->len = sizeof(detection); + return result; }
diff --git a/samples/quant_model/person_detection.h b/samples/quant_model/person_detection.h new file mode 100644 index 0000000..2e40ee1 --- /dev/null +++ b/samples/quant_model/person_detection.h
@@ -0,0 +1,11 @@ +#ifndef SAMPLES_PERSON_DETECTION_H +#define SAMPLES_PERSON_DETECTION_H + +#include <stdint.h> + +typedef struct { + int8_t non_person_score; + int8_t person_score; +} PersonDetectionOutput; + +#endif
diff --git a/samples/quant_model/scenenet_v2.c b/samples/quant_model/scenenet_v2.c index dbbf7fc..931c389 100644 --- a/samples/quant_model/scenenet_v2.c +++ b/samples/quant_model/scenenet_v2.c
@@ -44,16 +44,8 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { - iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); - return result; +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { + return iree_ok_status(); }
diff --git a/samples/quant_model/semantic_lift.c b/samples/quant_model/semantic_lift.c index 02f6cb6..2ce1e95 100644 --- a/samples/quant_model/semantic_lift.c +++ b/samples/quant_model/semantic_lift.c
@@ -44,16 +44,8 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { - iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); - return result; +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { + return iree_ok_status(); }
diff --git a/samples/quant_model/voice_commands.c b/samples/quant_model/voice_commands.c index 3444af0..4fdbbee 100644 --- a/samples/quant_model/voice_commands.c +++ b/samples/quant_model/voice_commands.c
@@ -44,16 +44,8 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { - iree_status_t result = iree_ok_status(); - if (index_output > model->num_output || - mapped_memory->contents.data_length / model->output_size_bytes != - model->output_length[index_output]) { - result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); - } - LOG_INFO("Output #%d data length: %d", index_output, - mapped_memory->contents.data_length / model->output_size_bytes); - return result; +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { + return iree_ok_status(); }
diff --git a/samples/simple_vec_mul/float_vec.c b/samples/simple_vec_mul/float_vec.c index 197684f..15c50e2 100644 --- a/samples/simple_vec_mul/float_vec.c +++ b/samples/simple_vec_mul/float_vec.c
@@ -60,13 +60,13 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { iree_status_t result = iree_ok_status(); - for (int i = 0; i < mapped_memory->contents.data_length / sizeof(float); + for (int i = 0; i < buffers[0].contents.data_length / sizeof(float); ++i) { - if (((const float *)mapped_memory->contents.data)[i] != i * i / 8.0f) { + if (((const float *)buffers[0].contents.data)[i] != i * i / 8.0f) { result = iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches"); break; }
diff --git a/samples/simple_vec_mul/int_vec.c b/samples/simple_vec_mul/int_vec.c index 824a031..6abf697 100644 --- a/samples/simple_vec_mul/int_vec.c +++ b/samples/simple_vec_mul/int_vec.c
@@ -60,13 +60,13 @@ return result; } -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output) { +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output) { iree_status_t result = iree_ok_status(); - for (int i = 0; i < mapped_memory->contents.data_length / sizeof(int32_t); + for (int i = 0; i < buffers[0].contents.data_length / sizeof(int32_t); ++i) { - if (((const int32_t *)mapped_memory->contents.data)[i] != (i >> 1) * i) { + if (((const int32_t *)buffers[0].contents.data)[i] != (i >> 1) * i) { result = iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches"); break; }
diff --git a/samples/util/model_api.h b/samples/util/model_api.h index e4a4f23..c5671c7 100644 --- a/samples/util/model_api.h +++ b/samples/util/model_api.h
@@ -29,6 +29,11 @@ char model_name[]; } MlModel; +typedef struct { + void* result; + uint32_t len; +} MlOutput; + // Load the VM bytecode module from the embedded c library into memory. const iree_const_byte_span_t load_bytecode_module_data(); @@ -44,10 +49,11 @@ // randomly generated stream, or a pointer from the sensor/ISP output. iree_status_t load_input_data(const MlModel *model, void **buffer); -// Check the ML execution output, and prepare the final data to be sent to the -// host with post processing. The final format is model dependent. -iree_status_t check_output_data(const MlModel *model, - iree_hal_buffer_mapping_t *mapped_memory, - int index_output); +// Process the ML execution output into the final data to be sent to the +// host. The final format is model dependent, so the address and size +// are returned via `output.` +iree_status_t process_output(const MlModel *model, + iree_hal_buffer_mapping_t *buffers, + MlOutput *output); #endif // SW_VEC_IREE_SAMPLES_UTIL_MODEL_API_H_
diff --git a/samples/util/util.c b/samples/util/util.c index cb674a0..b347358 100644 --- a/samples/util/util.c +++ b/samples/util/util.c
@@ -138,6 +138,8 @@ iree_allocator_system()); } + // Validate output and gather buffers. + iree_hal_buffer_mapping_t mapped_memories[MAX_MODEL_OUTPUTS] = {{0}}; for (int index_output = 0; index_output < model->num_output; index_output++) { iree_hal_buffer_view_t *ret_buffer_view = NULL; if (iree_status_is_ok(result)) { @@ -149,20 +151,34 @@ "can't find return buffer view"); } } - // Read back the results and ensure we got the right values. - iree_hal_buffer_mapping_t mapped_memory; if (iree_status_is_ok(result)) { result = iree_hal_buffer_map_range( iree_hal_buffer_view_buffer(ret_buffer_view), IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_READ, 0, - IREE_WHOLE_BUFFER, &mapped_memory); + IREE_WHOLE_BUFFER, &mapped_memories[index_output]); } + if (iree_status_is_ok(result)) { - result = check_output_data(model, &mapped_memory, index_output); - iree_hal_buffer_unmap_range(&mapped_memory); + if (index_output > model->num_output || + mapped_memories[index_output].contents.data_length / model->output_size_bytes != + model->output_length[index_output]) { + result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches"); + } } } + // Post-process memory into model output. + if (iree_status_is_ok(result)) { + MlOutput output = {.result = NULL, .len = 0}; + result = process_output(model, mapped_memories, &output); + // TODO(jesionowski): Populate CSRs with `output` after validating result. + } + + for (int index_output = 0; index_output < model->num_output; index_output++) { + if (mapped_memories[index_output].contents.data != NULL) { + iree_hal_buffer_unmap_range(&mapped_memories[index_output]); + } + } iree_vm_list_release(inputs); iree_vm_list_release(outputs); for (int i = 0; i < model->num_input; ++i) {