Refactor check_output_data and add MlOutput.
This CL moves the shared code in check_output_data into util.c and
creates a struct that represents the output of an execution. The values
of these structs will be set to CSRs after program completion in a
follow up CL.
As for what MlOutput points to: each model will need a shared struct
representing the format of the output, to be used in Rust
applications as well. An example is added to person_detection.
Change-Id: I7a36cd2fa85f2471671dd15347f5a4c5c1c849fd
diff --git a/samples/float_model/mnist.c b/samples/float_model/mnist.c
index 975dbea..98bfbbb 100644
--- a/samples/float_model/mnist.c
+++ b/samples/float_model/mnist.c
@@ -7,6 +7,7 @@
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "samples/util/util.h"
+#include "mnist.h"
// Compiled module embedded here to avoid file IO:
#include "samples/float_model/mnist_bytecode_module_dylib_c.h"
@@ -26,6 +27,8 @@
.model_name = "mnist",
};
+MnistOutput score;
+
const iree_const_byte_span_t load_bytecode_module_data() {
const struct iree_file_toc_t *module_file_toc =
samples_float_model_mnist_bytecode_module_dylib_create();
@@ -40,27 +43,27 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
// find the label index with best prediction
float best_out = 0.0;
int best_idx = -1;
- for (int i = 0; i < model->output_length[index_output]; ++i) {
- float out = ((float *)mapped_memory->contents.data)[i];
+ for (int i = 0; i < model->output_length[0]; ++i) {
+ float out = ((float *)buffers[0].contents.data)[i];
if (out > best_out) {
best_out = out;
best_idx = i;
}
}
+
+ score.best_out = best_out;
+ score.best_idx = best_idx;
+
LOG_INFO("Digit recognition result is: digit: %d", best_idx);
+
+ output->result = &score;
+ output->len = sizeof(score);
return result;
}
diff --git a/samples/float_model/mnist.h b/samples/float_model/mnist.h
new file mode 100644
index 0000000..67b2307
--- /dev/null
+++ b/samples/float_model/mnist.h
@@ -0,0 +1,11 @@
+#ifndef SAMPLES_MNIST_H
+#define SAMPLES_MNIST_H
+
+#include <stdint.h>
+
+typedef struct {
+ int best_idx;
+ float best_out;
+} MnistOutput;
+
+#endif
diff --git a/samples/float_model/mobilenet_v1.c b/samples/float_model/mobilenet_v1.c
index d354dfc..a0163f5 100644
--- a/samples/float_model/mobilenet_v1.c
+++ b/samples/float_model/mobilenet_v1.c
@@ -7,6 +7,7 @@
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "samples/util/util.h"
+#include "mobilenet_v1.h"
// Compiled module embedded here to avoid file IO:
#include "samples/float_model/mobilenet_input_c.h"
@@ -26,6 +27,8 @@
.model_name = "mobilenet_v1_0.25_224_float",
};
+MobilenetV1Output score;
+
const iree_const_byte_span_t load_bytecode_module_data() {
const struct iree_file_toc_t *module_file_toc =
samples_float_model_mobilenet_v1_bytecode_module_dylib_create();
@@ -40,27 +43,26 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
// find the label index with best prediction
float best_out = 0.0;
int best_idx = -1;
- for (int i = 0; i < model->output_length[index_output]; ++i) {
- float out = ((float *)mapped_memory->contents.data)[i];
+ for (int i = 0; i < model->output_length[0]; ++i) {
+ float out = ((float *)buffers[0].contents.data)[i];
if (out > best_out) {
best_out = out;
best_idx = i;
}
}
+ score.best_out = best_out;
+ score.best_idx = best_idx;
+
LOG_INFO("Image prediction result is: id: %d", best_idx + 1);
+
+ output->result = &score;
+ output->len = sizeof(score);
return result;
}
diff --git a/samples/float_model/mobilenet_v1.h b/samples/float_model/mobilenet_v1.h
new file mode 100644
index 0000000..1017ba6
--- /dev/null
+++ b/samples/float_model/mobilenet_v1.h
@@ -0,0 +1,11 @@
+#ifndef SAMPLES_MOBILENETV1_H
+#define SAMPLES_MOBILENETV1_H
+
+#include <stdint.h>
+
+typedef struct {
+ int best_idx;
+ float best_out;
+} MobilenetV1Output;
+
+#endif
diff --git a/samples/quant_model/barcode.c b/samples/quant_model/barcode.c
index d29ce8a..105c6da 100644
--- a/samples/quant_model/barcode.c
+++ b/samples/quant_model/barcode.c
@@ -46,16 +46,8 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
- iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
- return result;
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
+ return iree_ok_status();
}
diff --git a/samples/quant_model/daredevil.c b/samples/quant_model/daredevil.c
index 5b8754f..c20447c 100644
--- a/samples/quant_model/daredevil.c
+++ b/samples/quant_model/daredevil.c
@@ -44,16 +44,8 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
- iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
- return result;
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
+ return iree_ok_status();
}
diff --git a/samples/quant_model/fssd_25_8bit_v2.c b/samples/quant_model/fssd_25_8bit_v2.c
index ab71d4b..be979db 100644
--- a/samples/quant_model/fssd_25_8bit_v2.c
+++ b/samples/quant_model/fssd_25_8bit_v2.c
@@ -45,16 +45,8 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
- iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
- return result;
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
+ return iree_ok_status();
}
diff --git a/samples/quant_model/mobilenet_v1.c b/samples/quant_model/mobilenet_v1.c
index 3ece6b3..f55232d 100644
--- a/samples/quant_model/mobilenet_v1.c
+++ b/samples/quant_model/mobilenet_v1.c
@@ -7,6 +7,7 @@
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "samples/util/util.h"
+#include "mobilenet_v1.h"
// Compiled module embedded here to avoid file IO:
#include "samples/quant_model/mobilenet_quant_input_c.h"
@@ -26,6 +27,8 @@
.model_name = "mobilenet_v1_0.25_224_quant",
};
+MobilenetV1Output score;
+
const iree_const_byte_span_t load_bytecode_module_data() {
const struct iree_file_toc_t *module_file_toc =
samples_quant_model_mobilenet_v1_bytecode_module_dylib_create();
@@ -41,27 +44,26 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
// find the label index with best prediction
int best_out = 0;
int best_idx = -1;
- for (int i = 0; i < model->output_length[index_output]; ++i) {
- uint8_t out = ((uint8_t *)mapped_memory->contents.data)[i];
+ for (int i = 0; i < model->output_length[0]; ++i) {
+ uint8_t out = ((uint8_t *)buffers[0].contents.data)[i];
if (out > best_out) {
best_out = out;
best_idx = i;
}
}
+ score.best_out = best_out;
+ score.best_idx = best_idx;
+
LOG_INFO("Image prediction result is: id: %d", best_idx + 1);
+
+ output->result = &score;
+ output->len = sizeof(score);
return result;
}
diff --git a/samples/quant_model/mobilenet_v1.h b/samples/quant_model/mobilenet_v1.h
new file mode 100644
index 0000000..5277547
--- /dev/null
+++ b/samples/quant_model/mobilenet_v1.h
@@ -0,0 +1,11 @@
+#ifndef SAMPLES_MOBILENETV1_H
+#define SAMPLES_MOBILENETV1_H
+
+#include <stdint.h>
+
+typedef struct {
+ int best_idx;
+ int best_out;
+} MobilenetV1Output;
+
+#endif
diff --git a/samples/quant_model/mobilenet_v2.c b/samples/quant_model/mobilenet_v2.c
index 7451a58..612ce65 100644
--- a/samples/quant_model/mobilenet_v2.c
+++ b/samples/quant_model/mobilenet_v2.c
@@ -44,16 +44,8 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
- iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
- return result;
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
+ return iree_ok_status();
}
diff --git a/samples/quant_model/person_detection.c b/samples/quant_model/person_detection.c
index 22575dc..0f8e4ea 100644
--- a/samples/quant_model/person_detection.c
+++ b/samples/quant_model/person_detection.c
@@ -7,6 +7,7 @@
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "samples/util/util.h"
+#include "person_detection.h"
// Compiled module embedded here to avoid file IO:
#include "samples/quant_model/person_detection_bytecode_module_dylib_c.h"
@@ -26,6 +27,8 @@
.model_name = "person_detection_quant",
};
+PersonDetectionOutput detection;
+
const iree_const_byte_span_t load_bytecode_module_data() {
const struct iree_file_toc_t *module_file_toc =
samples_quant_model_person_detection_bytecode_module_dylib_create();
@@ -41,18 +44,19 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
- int8_t *data = (int8_t *)mapped_memory->contents.data;
- LOG_INFO("Output: Non-person Score: %d; Person Score: %d", data[0], data[1]);
+ int8_t *data = (int8_t *)buffers[0].contents.data;
+ detection.non_person_score = data[0];
+ detection.person_score = data[1];
+
+ LOG_INFO("Output: Non-person Score: %d; Person Score: %d",
+ detection.non_person_score,
+ detection.person_score);
+ output->result = &detection;
+ output->len = sizeof(detection);
+
return result;
}
diff --git a/samples/quant_model/person_detection.h b/samples/quant_model/person_detection.h
new file mode 100644
index 0000000..2e40ee1
--- /dev/null
+++ b/samples/quant_model/person_detection.h
@@ -0,0 +1,11 @@
+#ifndef SAMPLES_PERSON_DETECTION_H
+#define SAMPLES_PERSON_DETECTION_H
+
+#include <stdint.h>
+
+typedef struct {
+ int8_t non_person_score;
+ int8_t person_score;
+} PersonDetectionOutput;
+
+#endif
diff --git a/samples/quant_model/scenenet_v2.c b/samples/quant_model/scenenet_v2.c
index dbbf7fc..931c389 100644
--- a/samples/quant_model/scenenet_v2.c
+++ b/samples/quant_model/scenenet_v2.c
@@ -44,16 +44,8 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
- iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
- return result;
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
+ return iree_ok_status();
}
diff --git a/samples/quant_model/semantic_lift.c b/samples/quant_model/semantic_lift.c
index 02f6cb6..2ce1e95 100644
--- a/samples/quant_model/semantic_lift.c
+++ b/samples/quant_model/semantic_lift.c
@@ -44,16 +44,8 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
- iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
- return result;
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
+ return iree_ok_status();
}
diff --git a/samples/quant_model/voice_commands.c b/samples/quant_model/voice_commands.c
index 3444af0..4fdbbee 100644
--- a/samples/quant_model/voice_commands.c
+++ b/samples/quant_model/voice_commands.c
@@ -44,16 +44,8 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
- iree_status_t result = iree_ok_status();
- if (index_output > model->num_output ||
- mapped_memory->contents.data_length / model->output_size_bytes !=
- model->output_length[index_output]) {
- result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
- }
- LOG_INFO("Output #%d data length: %d", index_output,
- mapped_memory->contents.data_length / model->output_size_bytes);
- return result;
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
+ return iree_ok_status();
}
diff --git a/samples/simple_vec_mul/float_vec.c b/samples/simple_vec_mul/float_vec.c
index 197684f..15c50e2 100644
--- a/samples/simple_vec_mul/float_vec.c
+++ b/samples/simple_vec_mul/float_vec.c
@@ -60,13 +60,13 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
iree_status_t result = iree_ok_status();
- for (int i = 0; i < mapped_memory->contents.data_length / sizeof(float);
+ for (int i = 0; i < buffers[0].contents.data_length / sizeof(float);
++i) {
- if (((const float *)mapped_memory->contents.data)[i] != i * i / 8.0f) {
+ if (((const float *)buffers[0].contents.data)[i] != i * i / 8.0f) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches");
break;
}
diff --git a/samples/simple_vec_mul/int_vec.c b/samples/simple_vec_mul/int_vec.c
index 824a031..6abf697 100644
--- a/samples/simple_vec_mul/int_vec.c
+++ b/samples/simple_vec_mul/int_vec.c
@@ -60,13 +60,13 @@
return result;
}
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output) {
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output) {
iree_status_t result = iree_ok_status();
- for (int i = 0; i < mapped_memory->contents.data_length / sizeof(int32_t);
+ for (int i = 0; i < buffers[0].contents.data_length / sizeof(int32_t);
++i) {
- if (((const int32_t *)mapped_memory->contents.data)[i] != (i >> 1) * i) {
+ if (((const int32_t *)buffers[0].contents.data)[i] != (i >> 1) * i) {
result = iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches");
break;
}
diff --git a/samples/util/model_api.h b/samples/util/model_api.h
index e4a4f23..c5671c7 100644
--- a/samples/util/model_api.h
+++ b/samples/util/model_api.h
@@ -29,6 +29,11 @@
char model_name[];
} MlModel;
+typedef struct {
+ void* result;
+ uint32_t len;
+} MlOutput;
+
// Load the VM bytecode module from the embedded c library into memory.
const iree_const_byte_span_t load_bytecode_module_data();
@@ -44,10 +49,11 @@
// randomly generated stream, or a pointer from the sensor/ISP output.
iree_status_t load_input_data(const MlModel *model, void **buffer);
-// Check the ML execution output, and prepare the final data to be sent to the
-// host with post processing. The final format is model dependent.
-iree_status_t check_output_data(const MlModel *model,
- iree_hal_buffer_mapping_t *mapped_memory,
- int index_output);
+// Process the ML execution output into the final data to be sent to the
+// host. The final format is model dependent, so the address and size
+// are returned via `output.`
+iree_status_t process_output(const MlModel *model,
+ iree_hal_buffer_mapping_t *buffers,
+ MlOutput *output);
#endif // SW_VEC_IREE_SAMPLES_UTIL_MODEL_API_H_
diff --git a/samples/util/util.c b/samples/util/util.c
index cb674a0..b347358 100644
--- a/samples/util/util.c
+++ b/samples/util/util.c
@@ -138,6 +138,8 @@
iree_allocator_system());
}
+ // Validate output and gather buffers.
+ iree_hal_buffer_mapping_t mapped_memories[MAX_MODEL_OUTPUTS] = {{0}};
for (int index_output = 0; index_output < model->num_output; index_output++) {
iree_hal_buffer_view_t *ret_buffer_view = NULL;
if (iree_status_is_ok(result)) {
@@ -149,20 +151,34 @@
"can't find return buffer view");
}
}
- // Read back the results and ensure we got the right values.
- iree_hal_buffer_mapping_t mapped_memory;
if (iree_status_is_ok(result)) {
result = iree_hal_buffer_map_range(
iree_hal_buffer_view_buffer(ret_buffer_view),
IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_READ, 0,
- IREE_WHOLE_BUFFER, &mapped_memory);
+ IREE_WHOLE_BUFFER, &mapped_memories[index_output]);
}
+
if (iree_status_is_ok(result)) {
- result = check_output_data(model, &mapped_memory, index_output);
- iree_hal_buffer_unmap_range(&mapped_memory);
+ if (index_output > model->num_output ||
+ mapped_memories[index_output].contents.data_length / model->output_size_bytes !=
+ model->output_length[index_output]) {
+ result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
+ }
}
}
+ // Post-process memory into model output.
+ if (iree_status_is_ok(result)) {
+ MlOutput output = {.result = NULL, .len = 0};
+ result = process_output(model, mapped_memories, &output);
+ // TODO(jesionowski): Populate CSRs with `output` after validating result.
+ }
+
+ for (int index_output = 0; index_output < model->num_output; index_output++) {
+ if (mapped_memories[index_output].contents.data != NULL) {
+ iree_hal_buffer_unmap_range(&mapped_memories[index_output]);
+ }
+ }
iree_vm_list_release(inputs);
iree_vm_list_release(outputs);
for (int i = 0; i < model->num_input; ++i) {