Add realistic input for daredevil
Add realstic input for daredevil. Input is from: google3/knowledge/cerebra/parrot/port/android/platform/inference/daredevil/model_utils/test_data/golden_whistle_spectrogram
Returned output is correct (Whistling).
Change-Id: I325a89d1f9899285c35e7eee93448d43cc72728b
diff --git a/build_tools/gen_mlmodel_input.py b/build_tools/gen_mlmodel_input.py
index a79afc2..cc9a352 100755
--- a/build_tools/gen_mlmodel_input.py
+++ b/build_tools/gen_mlmodel_input.py
@@ -55,8 +55,8 @@
input_ext = os.path.splitext(input_name)[1]
if (not input_ext) or (input_ext == '.bin'):
with open(input_name, mode='rb') as f:
- input = np.fromfile(f, dtype=np.uint8 if is_quant else np.float32).reshape(
- np.prod(input_shape))
+ input = np.fromfile(f, dtype=np.uint8 if is_quant else np.float32)
+ input = input[:np.prod(input_shape)].reshape(np.prod(input_shape))
else:
resized_img = Image.open(input_name).resize(
(input_shape[1], input_shape[2]))
diff --git a/samples/quant_model/CMakeLists.txt b/samples/quant_model/CMakeLists.txt
index d7316db..95f2456 100644
--- a/samples/quant_model/CMakeLists.txt
+++ b/samples/quant_model/CMakeLists.txt
@@ -163,6 +163,16 @@
iree_model_input(
NAME
+ daredevil_quant_input
+ SHAPE
+ "1, 96, 64"
+ SRC
+ "$ENV{ROOTDIR}/ml/ml-models/test_data/golden_whistle_spectrogram"
+ QUANT
+)
+
+iree_model_input(
+ NAME
fssd_quant_input
SHAPE
"1, 480, 640, 1"
@@ -328,6 +338,7 @@
DEPS
::daredevil_bytecode_module_static
::daredevil_bytecode_module_static_c
+ ::daredevil_quant_input_c
iree::vm::bytecode_module
samples::util::util
LINKOPTS
@@ -343,6 +354,7 @@
DEPS
::daredevil_c_module_static_c
::daredevil_c_module_static_emitc
+ ::daredevil_quant_input_c
samples::util::util
LINKOPTS
"LINKER:--defsym=__itcm_length__=1M"
diff --git a/samples/quant_model/daredevil.c b/samples/quant_model/daredevil.c
index cb62305..178dc7e 100644
--- a/samples/quant_model/daredevil.c
+++ b/samples/quant_model/daredevil.c
@@ -20,11 +20,13 @@
#include <springbok.h>
+#include "daredevil.h"
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "samples/util/util.h"
// Compiled module embedded here to avoid file IO:
+#include "samples/quant_model/daredevil_quant_input_c.h"
#if !defined(BUILD_EMITC)
#include "samples/quant_model/daredevil_bytecode_module_static.h"
#include "samples/quant_model/daredevil_bytecode_module_static_c.h"
@@ -47,6 +49,8 @@
.model_name = "daredevil_quant",
};
+DaredevilOutput score;
+
iree_status_t create_module(iree_vm_module_t **module) {
#if !defined(BUILD_EMITC)
const struct iree_file_toc_t *module_file_toc =
@@ -69,22 +73,37 @@
iree_status_t load_input_data(const MlModel *model, void **buffer,
iree_const_byte_span_t **byte_span) {
- iree_status_t result = alloc_input_buffer(model, buffer);
- // Populate initial value
- srand(3689964);
- if (iree_status_is_ok(result)) {
- for (int i = 0; i < model->input_length[0]; ++i) {
- ((uint8_t *)*buffer)[i] = (uint8_t)rand();
- }
- }
byte_span[0] = malloc(sizeof(iree_const_byte_span_t));
*byte_span[0] = iree_make_const_byte_span(
- buffer[0], model->input_size_bytes[0] * model->input_length[0]);
- return result;
+ daredevil_quant_input,
+ model->input_size_bytes[0] * model->input_length[0]);
+
+
+
+ return iree_ok_status();
}
iree_status_t process_output(const MlModel *model,
iree_hal_buffer_mapping_t *buffers,
MlOutput *output) {
- return iree_ok_status();
+ iree_status_t result = iree_ok_status();
+ // find the label index with best prediction
+ int best_out = 0;
+ int best_idx = -1;
+ for (int i = 0; i < model->output_length[0]; ++i) {
+ uint8_t out = ((uint8_t *)buffers[0].contents.data)[i];
+ if (out > best_out) {
+ best_out = out;
+ best_idx = i;
+ }
+ }
+ score.best_out = best_out;
+ score.best_idx = best_idx;
+
+ LOG_INFO("Best prediction result is: id: %d", best_idx + 1);
+ LOG_INFO("Id # 41 is Whistling");
+
+ output->result = &score;
+ output->len = sizeof(score);
+ return result;
}
diff --git a/samples/quant_model/daredevil.h b/samples/quant_model/daredevil.h
new file mode 100644
index 0000000..fdaf5c8
--- /dev/null
+++ b/samples/quant_model/daredevil.h
@@ -0,0 +1,27 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SAMPLES_QUANT_MODEL_DAREDEVIL_H_
+#define SAMPLES_QUANT_MODEL_DAREDEVIL_H_
+
+#include <stdint.h>
+
+typedef struct {
+ int best_idx;
+ int best_out;
+} DaredevilOutput;
+
+#endif
diff --git a/samples/quant_model/daredevil_bytecode_static_test.txt b/samples/quant_model/daredevil_bytecode_static_test.txt
new file mode 100644
index 0000000..34c3e3f
--- /dev/null
+++ b/samples/quant_model/daredevil_bytecode_static_test.txt
@@ -0,0 +1,4 @@
+// RUN: ${TEST_RUNNER_CMD} ${OUT}/springbok_iree/samples/quant_model/daredevil_bytecode_static 2>&1 | tee %t
+// RUN: cat %t | FileCheck %s
+// CHECK: {{Best prediction result is: id: 41}}
+// REQUIRES: internal
diff --git a/samples/quant_model/daredevil_emitc_static_test.txt b/samples/quant_model/daredevil_emitc_static_test.txt
new file mode 100644
index 0000000..25fc487
--- /dev/null
+++ b/samples/quant_model/daredevil_emitc_static_test.txt
@@ -0,0 +1,4 @@
+// RUN: ${TEST_RUNNER_CMD} ${OUT}/springbok_iree/samples/quant_model/daredevil_emitc_static 2>&1 | tee %t
+// RUN: cat %t | FileCheck %s
+// CHECK: {{Best prediction result is: id: 41}}
+// REQUIRES: internal
diff --git a/samples/quant_model/daredevil_test.txt b/samples/quant_model/daredevil_test.txt
deleted file mode 100644
index b710964..0000000
--- a/samples/quant_model/daredevil_test.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-// RUN: ${TEST_RUNNER_CMD} ${OUT}/springbok_iree/samples/quant_model/daredevil_bytecode_static
-// RUN: ${TEST_RUNNER_CMD} ${OUT}/springbok_iree/samples/quant_model/daredevil_emitc_static
-// REQUIRES: internal