Add realistic input for daredevil Add realstic input for daredevil. Input is from: google3/knowledge/cerebra/parrot/port/android/platform/inference/daredevil/model_utils/test_data/golden_whistle_spectrogram Returned output is correct (Whistling). Change-Id: I325a89d1f9899285c35e7eee93448d43cc72728b
diff --git a/build_tools/gen_mlmodel_input.py b/build_tools/gen_mlmodel_input.py index a79afc2..cc9a352 100755 --- a/build_tools/gen_mlmodel_input.py +++ b/build_tools/gen_mlmodel_input.py
@@ -55,8 +55,8 @@ input_ext = os.path.splitext(input_name)[1] if (not input_ext) or (input_ext == '.bin'): with open(input_name, mode='rb') as f: - input = np.fromfile(f, dtype=np.uint8 if is_quant else np.float32).reshape( - np.prod(input_shape)) + input = np.fromfile(f, dtype=np.uint8 if is_quant else np.float32) + input = input[:np.prod(input_shape)].reshape(np.prod(input_shape)) else: resized_img = Image.open(input_name).resize( (input_shape[1], input_shape[2]))
diff --git a/samples/quant_model/CMakeLists.txt b/samples/quant_model/CMakeLists.txt index d7316db..95f2456 100644 --- a/samples/quant_model/CMakeLists.txt +++ b/samples/quant_model/CMakeLists.txt
@@ -163,6 +163,16 @@ iree_model_input( NAME + daredevil_quant_input + SHAPE + "1, 96, 64" + SRC + "$ENV{ROOTDIR}/ml/ml-models/test_data/golden_whistle_spectrogram" + QUANT +) + +iree_model_input( + NAME fssd_quant_input SHAPE "1, 480, 640, 1" @@ -328,6 +338,7 @@ DEPS ::daredevil_bytecode_module_static ::daredevil_bytecode_module_static_c + ::daredevil_quant_input_c iree::vm::bytecode_module samples::util::util LINKOPTS @@ -343,6 +354,7 @@ DEPS ::daredevil_c_module_static_c ::daredevil_c_module_static_emitc + ::daredevil_quant_input_c samples::util::util LINKOPTS "LINKER:--defsym=__itcm_length__=1M"
diff --git a/samples/quant_model/daredevil.c b/samples/quant_model/daredevil.c index cb62305..178dc7e 100644 --- a/samples/quant_model/daredevil.c +++ b/samples/quant_model/daredevil.c
@@ -20,11 +20,13 @@ #include <springbok.h> +#include "daredevil.h" #include "iree/base/api.h" #include "iree/hal/api.h" #include "samples/util/util.h" // Compiled module embedded here to avoid file IO: +#include "samples/quant_model/daredevil_quant_input_c.h" #if !defined(BUILD_EMITC) #include "samples/quant_model/daredevil_bytecode_module_static.h" #include "samples/quant_model/daredevil_bytecode_module_static_c.h" @@ -47,6 +49,8 @@ .model_name = "daredevil_quant", }; +DaredevilOutput score; + iree_status_t create_module(iree_vm_module_t **module) { #if !defined(BUILD_EMITC) const struct iree_file_toc_t *module_file_toc = @@ -69,22 +73,37 @@ iree_status_t load_input_data(const MlModel *model, void **buffer, iree_const_byte_span_t **byte_span) { - iree_status_t result = alloc_input_buffer(model, buffer); - // Populate initial value - srand(3689964); - if (iree_status_is_ok(result)) { - for (int i = 0; i < model->input_length[0]; ++i) { - ((uint8_t *)*buffer)[i] = (uint8_t)rand(); - } - } byte_span[0] = malloc(sizeof(iree_const_byte_span_t)); *byte_span[0] = iree_make_const_byte_span( - buffer[0], model->input_size_bytes[0] * model->input_length[0]); - return result; + daredevil_quant_input, + model->input_size_bytes[0] * model->input_length[0]); + + + + return iree_ok_status(); } iree_status_t process_output(const MlModel *model, iree_hal_buffer_mapping_t *buffers, MlOutput *output) { - return iree_ok_status(); + iree_status_t result = iree_ok_status(); + // find the label index with best prediction + int best_out = 0; + int best_idx = -1; + for (int i = 0; i < model->output_length[0]; ++i) { + uint8_t out = ((uint8_t *)buffers[0].contents.data)[i]; + if (out > best_out) { + best_out = out; + best_idx = i; + } + } + score.best_out = best_out; + score.best_idx = best_idx; + + LOG_INFO("Best prediction result is: id: %d", best_idx + 1); + LOG_INFO("Id # 41 is Whistling"); + + output->result = &score; + output->len = sizeof(score); + return result; }
diff --git a/samples/quant_model/daredevil.h b/samples/quant_model/daredevil.h new file mode 100644 index 0000000..fdaf5c8 --- /dev/null +++ b/samples/quant_model/daredevil.h
@@ -0,0 +1,27 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SAMPLES_QUANT_MODEL_DAREDEVIL_H_ +#define SAMPLES_QUANT_MODEL_DAREDEVIL_H_ + +#include <stdint.h> + +typedef struct { + int best_idx; + int best_out; +} DaredevilOutput; + +#endif
diff --git a/samples/quant_model/daredevil_bytecode_static_test.txt b/samples/quant_model/daredevil_bytecode_static_test.txt new file mode 100644 index 0000000..34c3e3f --- /dev/null +++ b/samples/quant_model/daredevil_bytecode_static_test.txt
@@ -0,0 +1,4 @@ +// RUN: ${TEST_RUNNER_CMD} ${OUT}/springbok_iree/samples/quant_model/daredevil_bytecode_static 2>&1 | tee %t +// RUN: cat %t | FileCheck %s +// CHECK: {{Best prediction result is: id: 41}} +// REQUIRES: internal
diff --git a/samples/quant_model/daredevil_emitc_static_test.txt b/samples/quant_model/daredevil_emitc_static_test.txt new file mode 100644 index 0000000..25fc487 --- /dev/null +++ b/samples/quant_model/daredevil_emitc_static_test.txt
@@ -0,0 +1,4 @@ +// RUN: ${TEST_RUNNER_CMD} ${OUT}/springbok_iree/samples/quant_model/daredevil_emitc_static 2>&1 | tee %t +// RUN: cat %t | FileCheck %s +// CHECK: {{Best prediction result is: id: 41}} +// REQUIRES: internal
diff --git a/samples/quant_model/daredevil_test.txt b/samples/quant_model/daredevil_test.txt deleted file mode 100644 index b710964..0000000 --- a/samples/quant_model/daredevil_test.txt +++ /dev/null
@@ -1,3 +0,0 @@ -// RUN: ${TEST_RUNNER_CMD} ${OUT}/springbok_iree/samples/quant_model/daredevil_bytecode_static -// RUN: ${TEST_RUNNER_CMD} ${OUT}/springbok_iree/samples/quant_model/daredevil_emitc_static -// REQUIRES: internal