Refactor check_output_data and add MlOutput.

This CL moves the shared code in check_output_data into util.c and
creates a struct that represents the output of an execution. The values
of these structs will be set to CSRs after program completion in a
follow up CL.

As for what MlOutput points to: each model will need a shared struct
representing the format of the output, to be used in Rust
applications as well. An example is added to person_detection.

Change-Id: I7a36cd2fa85f2471671dd15347f5a4c5c1c849fd
diff --git a/samples/float_model/mnist.c b/samples/float_model/mnist.c
index 975dbea..98bfbbb 100644
--- a/samples/float_model/mnist.c
+++ b/samples/float_model/mnist.c
@@ -7,6 +7,7 @@
 #include "iree/base/api.h"
 #include "iree/hal/api.h"
 #include "samples/util/util.h"
+#include "mnist.h"
 
 // Compiled module embedded here to avoid file IO:
 #include "samples/float_model/mnist_bytecode_module_dylib_c.h"
@@ -26,6 +27,8 @@
     .model_name = "mnist",
 };
 
+MnistOutput score;
+
 const iree_const_byte_span_t load_bytecode_module_data() {
   const struct iree_file_toc_t *module_file_toc =
       samples_float_model_mnist_bytecode_module_dylib_create();
@@ -40,27 +43,27 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
+iree_status_t process_output(const MlModel *model,
+                  iree_hal_buffer_mapping_t *buffers,
+                  MlOutput *output) {
   iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
   // find the label index with best prediction
   float best_out = 0.0;
   int best_idx = -1;
-  for (int i = 0; i < model->output_length[index_output]; ++i) {
-    float out = ((float *)mapped_memory->contents.data)[i];
+  for (int i = 0; i < model->output_length[0]; ++i) {
+    float out = ((float *)buffers[0].contents.data)[i];
     if (out > best_out) {
       best_out = out;
       best_idx = i;
     }
   }
+
+  score.best_out = best_out;
+  score.best_idx = best_idx;
+
   LOG_INFO("Digit recognition result is: digit: %d", best_idx);
+
+  output->result = &score;
+  output->len = sizeof(score);
   return result;
 }
diff --git a/samples/float_model/mnist.h b/samples/float_model/mnist.h
new file mode 100644
index 0000000..67b2307
--- /dev/null
+++ b/samples/float_model/mnist.h
@@ -0,0 +1,11 @@
+#ifndef SAMPLES_MNIST_H
+#define SAMPLES_MNIST_H
+
+#include <stdint.h>
+
+typedef struct {
+    int best_idx;
+    float best_out;
+} MnistOutput;
+
+#endif
diff --git a/samples/float_model/mobilenet_v1.c b/samples/float_model/mobilenet_v1.c
index d354dfc..a0163f5 100644
--- a/samples/float_model/mobilenet_v1.c
+++ b/samples/float_model/mobilenet_v1.c
@@ -7,6 +7,7 @@
 #include "iree/base/api.h"
 #include "iree/hal/api.h"
 #include "samples/util/util.h"
+#include "mobilenet_v1.h"
 
 // Compiled module embedded here to avoid file IO:
 #include "samples/float_model/mobilenet_input_c.h"
@@ -26,6 +27,8 @@
     .model_name = "mobilenet_v1_0.25_224_float",
 };
 
+MobilenetV1Output score;
+
 const iree_const_byte_span_t load_bytecode_module_data() {
   const struct iree_file_toc_t *module_file_toc =
       samples_float_model_mobilenet_v1_bytecode_module_dylib_create();
@@ -40,27 +43,26 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
+iree_status_t process_output(const MlModel *model,
+                  iree_hal_buffer_mapping_t *buffers,
+                  MlOutput *output) {
   iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
   // find the label index with best prediction
   float best_out = 0.0;
   int best_idx = -1;
-  for (int i = 0; i < model->output_length[index_output]; ++i) {
-    float out = ((float *)mapped_memory->contents.data)[i];
+  for (int i = 0; i < model->output_length[0]; ++i) {
+    float out = ((float *)buffers[0].contents.data)[i];
     if (out > best_out) {
       best_out = out;
       best_idx = i;
     }
   }
+  score.best_out = best_out;
+  score.best_idx = best_idx;
+
   LOG_INFO("Image prediction result is: id: %d", best_idx + 1);
+
+  output->result = &score;
+  output->len = sizeof(score);
   return result;
 }
diff --git a/samples/float_model/mobilenet_v1.h b/samples/float_model/mobilenet_v1.h
new file mode 100644
index 0000000..1017ba6
--- /dev/null
+++ b/samples/float_model/mobilenet_v1.h
@@ -0,0 +1,11 @@
+#ifndef SAMPLES_MOBILENETV1_H
+#define SAMPLES_MOBILENETV1_H
+
+#include <stdint.h>
+
+typedef struct {
+    int best_idx;
+    float best_out;
+} MobilenetV1Output;
+
+#endif
diff --git a/samples/quant_model/barcode.c b/samples/quant_model/barcode.c
index d29ce8a..105c6da 100644
--- a/samples/quant_model/barcode.c
+++ b/samples/quant_model/barcode.c
@@ -46,16 +46,8 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
-  iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
-  return result;
+iree_status_t process_output(const MlModel *model,
+                              iree_hal_buffer_mapping_t *buffers,
+                              MlOutput *output) {
+  return iree_ok_status();
 }
diff --git a/samples/quant_model/daredevil.c b/samples/quant_model/daredevil.c
index 5b8754f..c20447c 100644
--- a/samples/quant_model/daredevil.c
+++ b/samples/quant_model/daredevil.c
@@ -44,16 +44,8 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
-  iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
-  return result;
+iree_status_t process_output(const MlModel *model,
+                              iree_hal_buffer_mapping_t *buffers,
+                              MlOutput *output) {
+  return iree_ok_status();
 }
diff --git a/samples/quant_model/fssd_25_8bit_v2.c b/samples/quant_model/fssd_25_8bit_v2.c
index ab71d4b..be979db 100644
--- a/samples/quant_model/fssd_25_8bit_v2.c
+++ b/samples/quant_model/fssd_25_8bit_v2.c
@@ -45,16 +45,8 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
-  iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
-  return result;
+iree_status_t process_output(const MlModel *model,
+                              iree_hal_buffer_mapping_t *buffers,
+                              MlOutput *output) {
+  return iree_ok_status();
 }
diff --git a/samples/quant_model/mobilenet_v1.c b/samples/quant_model/mobilenet_v1.c
index 3ece6b3..f55232d 100644
--- a/samples/quant_model/mobilenet_v1.c
+++ b/samples/quant_model/mobilenet_v1.c
@@ -7,6 +7,7 @@
 #include "iree/base/api.h"
 #include "iree/hal/api.h"
 #include "samples/util/util.h"
+#include "mobilenet_v1.h"
 
 // Compiled module embedded here to avoid file IO:
 #include "samples/quant_model/mobilenet_quant_input_c.h"
@@ -26,6 +27,8 @@
     .model_name = "mobilenet_v1_0.25_224_quant",
 };
 
+MobilenetV1Output score;
+
 const iree_const_byte_span_t load_bytecode_module_data() {
   const struct iree_file_toc_t *module_file_toc =
       samples_quant_model_mobilenet_v1_bytecode_module_dylib_create();
@@ -41,27 +44,26 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
+iree_status_t process_output(const MlModel *model,
+                  iree_hal_buffer_mapping_t *buffers,
+                  MlOutput *output) {
   iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
   // find the label index with best prediction
   int best_out = 0;
   int best_idx = -1;
-  for (int i = 0; i < model->output_length[index_output]; ++i) {
-    uint8_t out = ((uint8_t *)mapped_memory->contents.data)[i];
+  for (int i = 0; i < model->output_length[0]; ++i) {
+    uint8_t out = ((uint8_t *)buffers[0].contents.data)[i];
     if (out > best_out) {
       best_out = out;
       best_idx = i;
     }
   }
+  score.best_out = best_out;
+  score.best_idx = best_idx;
+
   LOG_INFO("Image prediction result is: id: %d", best_idx + 1);
+
+  output->result = &score;
+  output->len = sizeof(score);
   return result;
 }
diff --git a/samples/quant_model/mobilenet_v1.h b/samples/quant_model/mobilenet_v1.h
new file mode 100644
index 0000000..5277547
--- /dev/null
+++ b/samples/quant_model/mobilenet_v1.h
@@ -0,0 +1,11 @@
+#ifndef SAMPLES_MOBILENETV1_H
+#define SAMPLES_MOBILENETV1_H
+
+#include <stdint.h>
+
+typedef struct {
+    int best_idx;
+    int best_out;
+} MobilenetV1Output;
+
+#endif
diff --git a/samples/quant_model/mobilenet_v2.c b/samples/quant_model/mobilenet_v2.c
index 7451a58..612ce65 100644
--- a/samples/quant_model/mobilenet_v2.c
+++ b/samples/quant_model/mobilenet_v2.c
@@ -44,16 +44,8 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
-  iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
-  return result;
+iree_status_t process_output(const MlModel *model,
+                              iree_hal_buffer_mapping_t *buffers,
+                              MlOutput *output) {
+  return iree_ok_status();
 }
diff --git a/samples/quant_model/person_detection.c b/samples/quant_model/person_detection.c
index 22575dc..0f8e4ea 100644
--- a/samples/quant_model/person_detection.c
+++ b/samples/quant_model/person_detection.c
@@ -7,6 +7,7 @@
 #include "iree/base/api.h"
 #include "iree/hal/api.h"
 #include "samples/util/util.h"
+#include "person_detection.h"
 
 // Compiled module embedded here to avoid file IO:
 #include "samples/quant_model/person_detection_bytecode_module_dylib_c.h"
@@ -26,6 +27,8 @@
     .model_name = "person_detection_quant",
 };
 
+PersonDetectionOutput detection;
+
 const iree_const_byte_span_t load_bytecode_module_data() {
   const struct iree_file_toc_t *module_file_toc =
       samples_quant_model_person_detection_bytecode_module_dylib_create();
@@ -41,18 +44,19 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
+iree_status_t process_output(const MlModel *model,
+                              iree_hal_buffer_mapping_t *buffers,
+                              MlOutput *output) {
   iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
-  int8_t *data = (int8_t *)mapped_memory->contents.data;
-  LOG_INFO("Output: Non-person Score: %d; Person Score: %d", data[0], data[1]);
+  int8_t *data = (int8_t *)buffers[0].contents.data;
+  detection.non_person_score = data[0];
+  detection.person_score = data[1];
+
+  LOG_INFO("Output: Non-person Score: %d; Person Score: %d",
+            detection.non_person_score,
+            detection.person_score);
+  output->result = &detection;
+  output->len = sizeof(detection);
+
   return result;
 }
diff --git a/samples/quant_model/person_detection.h b/samples/quant_model/person_detection.h
new file mode 100644
index 0000000..2e40ee1
--- /dev/null
+++ b/samples/quant_model/person_detection.h
@@ -0,0 +1,11 @@
+#ifndef SAMPLES_PERSON_DETECTION_H
+#define SAMPLES_PERSON_DETECTION_H
+
+#include <stdint.h>
+
+typedef struct {
+    int8_t non_person_score;
+    int8_t person_score;
+} PersonDetectionOutput;
+
+#endif
diff --git a/samples/quant_model/scenenet_v2.c b/samples/quant_model/scenenet_v2.c
index dbbf7fc..931c389 100644
--- a/samples/quant_model/scenenet_v2.c
+++ b/samples/quant_model/scenenet_v2.c
@@ -44,16 +44,8 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
-  iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
-  return result;
+iree_status_t process_output(const MlModel *model,
+                              iree_hal_buffer_mapping_t *buffers,
+                              MlOutput *output) {
+  return iree_ok_status();
 }
diff --git a/samples/quant_model/semantic_lift.c b/samples/quant_model/semantic_lift.c
index 02f6cb6..2ce1e95 100644
--- a/samples/quant_model/semantic_lift.c
+++ b/samples/quant_model/semantic_lift.c
@@ -44,16 +44,8 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
-  iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
-  return result;
+iree_status_t process_output(const MlModel *model,
+                              iree_hal_buffer_mapping_t *buffers,
+                              MlOutput *output) {
+  return iree_ok_status();
 }
diff --git a/samples/quant_model/voice_commands.c b/samples/quant_model/voice_commands.c
index 3444af0..4fdbbee 100644
--- a/samples/quant_model/voice_commands.c
+++ b/samples/quant_model/voice_commands.c
@@ -44,16 +44,8 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
-  iree_status_t result = iree_ok_status();
-  if (index_output > model->num_output ||
-      mapped_memory->contents.data_length / model->output_size_bytes !=
-          model->output_length[index_output]) {
-    result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
-  }
-  LOG_INFO("Output #%d data length: %d", index_output,
-           mapped_memory->contents.data_length / model->output_size_bytes);
-  return result;
+iree_status_t process_output(const MlModel *model,
+                              iree_hal_buffer_mapping_t *buffers,
+                              MlOutput *output) {
+  return iree_ok_status();
 }
diff --git a/samples/simple_vec_mul/float_vec.c b/samples/simple_vec_mul/float_vec.c
index 197684f..15c50e2 100644
--- a/samples/simple_vec_mul/float_vec.c
+++ b/samples/simple_vec_mul/float_vec.c
@@ -60,13 +60,13 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
+iree_status_t process_output(const MlModel *model,
+                  iree_hal_buffer_mapping_t *buffers,
+                  MlOutput *output) {
   iree_status_t result = iree_ok_status();
-  for (int i = 0; i < mapped_memory->contents.data_length / sizeof(float);
+  for (int i = 0; i < buffers[0].contents.data_length / sizeof(float);
        ++i) {
-    if (((const float *)mapped_memory->contents.data)[i] != i * i / 8.0f) {
+    if (((const float *)buffers[0].contents.data)[i] != i * i / 8.0f) {
       result = iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches");
       break;
     }
diff --git a/samples/simple_vec_mul/int_vec.c b/samples/simple_vec_mul/int_vec.c
index 824a031..6abf697 100644
--- a/samples/simple_vec_mul/int_vec.c
+++ b/samples/simple_vec_mul/int_vec.c
@@ -60,13 +60,13 @@
   return result;
 }
 
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output) {
+iree_status_t process_output(const MlModel *model,
+                  iree_hal_buffer_mapping_t *buffers,
+                  MlOutput *output) {
   iree_status_t result = iree_ok_status();
-  for (int i = 0; i < mapped_memory->contents.data_length / sizeof(int32_t);
+  for (int i = 0; i < buffers[0].contents.data_length / sizeof(int32_t);
        ++i) {
-    if (((const int32_t *)mapped_memory->contents.data)[i] != (i >> 1) * i) {
+    if (((const int32_t *)buffers[0].contents.data)[i] != (i >> 1) * i) {
       result = iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches");
       break;
     }
diff --git a/samples/util/model_api.h b/samples/util/model_api.h
index e4a4f23..c5671c7 100644
--- a/samples/util/model_api.h
+++ b/samples/util/model_api.h
@@ -29,6 +29,11 @@
   char model_name[];
 } MlModel;
 
+typedef struct {
+  void* result;
+  uint32_t len;
+} MlOutput;
+
 // Load the VM bytecode module from the embedded c library into memory.
 const iree_const_byte_span_t load_bytecode_module_data();
 
@@ -44,10 +49,11 @@
 // randomly generated stream, or a pointer from the sensor/ISP output.
 iree_status_t load_input_data(const MlModel *model, void **buffer);
 
-// Check the ML execution output, and prepare the final data to be sent to the
-// host with post processing. The final format is model dependent.
-iree_status_t check_output_data(const MlModel *model,
-                                iree_hal_buffer_mapping_t *mapped_memory,
-                                int index_output);
+// Process the ML execution output into the final data to be sent to the
+// host. The final format is model dependent, so the address and size
+// are returned via `output.`
+iree_status_t process_output(const MlModel *model,
+                              iree_hal_buffer_mapping_t *buffers,
+                              MlOutput *output);
 
 #endif  // SW_VEC_IREE_SAMPLES_UTIL_MODEL_API_H_
diff --git a/samples/util/util.c b/samples/util/util.c
index cb674a0..b347358 100644
--- a/samples/util/util.c
+++ b/samples/util/util.c
@@ -138,6 +138,8 @@
                             iree_allocator_system());
   }
 
+  // Validate output and gather buffers.
+  iree_hal_buffer_mapping_t mapped_memories[MAX_MODEL_OUTPUTS] = {{0}};
   for (int index_output = 0; index_output < model->num_output; index_output++) {
     iree_hal_buffer_view_t *ret_buffer_view = NULL;
     if (iree_status_is_ok(result)) {
@@ -149,20 +151,34 @@
                                   "can't find return buffer view");
       }
     }
-    // Read back the results and ensure we got the right values.
-    iree_hal_buffer_mapping_t mapped_memory;
     if (iree_status_is_ok(result)) {
       result = iree_hal_buffer_map_range(
           iree_hal_buffer_view_buffer(ret_buffer_view),
           IREE_HAL_MAPPING_MODE_SCOPED, IREE_HAL_MEMORY_ACCESS_READ, 0,
-          IREE_WHOLE_BUFFER, &mapped_memory);
+          IREE_WHOLE_BUFFER, &mapped_memories[index_output]);
     }
+
     if (iree_status_is_ok(result)) {
-      result = check_output_data(model, &mapped_memory, index_output);
-      iree_hal_buffer_unmap_range(&mapped_memory);
+      if (index_output > model->num_output ||
+            mapped_memories[index_output].contents.data_length / model->output_size_bytes !=
+            model->output_length[index_output]) {
+        result = iree_make_status(IREE_STATUS_UNKNOWN, "output length mismatches");
+      }
     }
   }
 
+  // Post-process memory into model output.
+  if (iree_status_is_ok(result)) {
+    MlOutput output = {.result = NULL, .len = 0};
+    result = process_output(model, mapped_memories, &output);
+    // TODO(jesionowski): Populate CSRs with `output` after validating result.
+  }
+
+  for (int index_output = 0; index_output < model->num_output; index_output++) {
+    if (mapped_memories[index_output].contents.data != NULL) {
+      iree_hal_buffer_unmap_range(&mapped_memories[index_output]);
+    }
+  }
   iree_vm_list_release(inputs);
   iree_vm_list_release(outputs);
   for (int i = 0; i < model->num_input; ++i) {