/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// SSD box decoding and extraction

#include "ssd_postprocess/box.h"

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Active SSD post-processing parameters read by the functions in this file.
// set_params() overwrites the whole struct.
// NOTE(review): the defaults below presumably describe one specific quantized
// 320x320 model -- confirm against the model they were exported from.
static SsdParams params = {
    // Number of output layers; the per-layer arrays below are indexed by layer.
    .num_layers = 4,
    // Total number of anchors/boxes over all layers (1200 + 300 + 75 + 27 with
    // these defaults); sizes the work buffers in get_detected_boxes().
    .num_boxes = 1602,
    // Input size; used to derive each layer's grid size and to normalize
    // anchor positions and sizes.
    .input_height = 320,
    .input_width = 320,
    // Divisors applied to the decoded box offsets in convert_box().
    .global_scales = {10, 10, 5, 5},  // y, x, h, w
    // Per-layer quantization parameters of the box outputs.
    .box_zero_points = {115, 129, 125, 119},
    .box_scales = {0.0813235, 0.0786732, 0.0687513, 0.0522251},
    // Per-layer quantization parameters of the score outputs.
    .score_zero_points = {211, 195, 200, 225},
    .score_scales = {0.177373, 0.121247, 0.100491, 0.0550178},
    // A box is kept when its sigmoid score is strictly greater than this.
    .score_threshold = 0.5,
    // Number of anchors generated for every grid cell.
    .anchors_per_cell = 3,
    // Anchor sizes, indexed as [layer * anchors_per_cell + anchor]; divided by
    // the input height/width when anchors are generated.
    .anchor_base_size = {24.0, 32.0, 40.0, 48.0, 64.0, 80.0, 96.0, 128.0, 160.0,
                         192.0, 256.0, 320.0},
    // Per-layer grid stride; grid size is ceil(input size / stride).
    .anchor_stride = {16, 32, 64, 128}};

// Replace the active SSD parameters with a copy of the caller's struct.
// The contents of *params_in are copied, so the caller keeps ownership of it.
void set_params(SsdParams* params_in) {
  SsdParams* const active = &params;
  *active = *params_in;
}

// Map a quantized sample back to its real value:
//   real = scale * (val - zero_point)
static inline float dequantize(int val, int zero_point, float scale) {
  const int centered = val - zero_point;
  return scale * centered;
}

static inline float sigmoid(float val) { return 1.0 / (1.0 + expf(-val)); }

// Generate model anchors
// With the default parameters:
// layer0: 20 * 20 * 3 = 1200
// layer1: 10 * 10 * 3 = 300
// layer2:   5 * 5 * 3 = 75
// layer3:   3 * 3 * 3 = 27
// total sum:            1602
//
// anchors: output array with room for params.num_boxes entries. Entries are
// written in the order layer, grid row, grid column, anchor within the cell.
// detect_boxes() pairs anchors[i] with the i-th decoded box.
//
// At most params.num_boxes entries are written: if the layer/stride
// configuration describes more anchors than that, generation stops early
// instead of writing past the end of the array.
static void generate_anchors(BoxCenterEncode* anchors) {
  int idx = 0;
  for (int layer = 0; layer < params.num_layers; ++layer) {
    // Grid size of this layer: ceil(input size / stride).
    int height_size = (params.input_height + params.anchor_stride[layer] - 1) /
                      params.anchor_stride[layer];
    int width_size = (params.input_width + params.anchor_stride[layer] - 1) /
                     params.anchor_stride[layer];
    for (int h = 0; h < height_size; h++) {
      for (int w = 0; w < width_size; w++) {
        for (int base = 0; base < params.anchors_per_cell; ++base) {
          // The caller sized the array from params.num_boxes; never exceed it
          // even if the other parameters are inconsistent with that count.
          if (idx >= params.num_boxes) {
            return;
          }
          // Anchor position: cell origin (stride * index) normalized by the
          // input size.
          anchors[idx].y =
              (float)params.anchor_stride[layer] * h / params.input_height;
          anchors[idx].x =
              (float)params.anchor_stride[layer] * w / params.input_width;
          // Anchor size: base size normalized by the input size.
          anchors[idx].h =
              params.anchor_base_size[layer * params.anchors_per_cell + base] /
              params.input_height;
          anchors[idx].w =
              params.anchor_base_size[layer * params.anchors_per_cell + base] /
              params.input_width;
          idx++;
        }
      }
    }
  }
}

// Decode boxes (with score) from model inference outputs
// The locations channel dim is 16 x 3.
// Each 16 is composed of (4 box coordinates + 6 * 2 landmarks coordinates).
// We need only the first 4 box coordinates - so want to keep only indexes:
//  0, 1, 2, 3
// 16,17,18,19
// 32,33,34,35
//
// model_out: per-layer output buffers; model_out[2 * layer] holds the box
//            locations and model_out[2 * layer + 1] the scores of that layer.
// boxes:     output array with room for params.num_boxes entries. Each entry
//            receives the dequantized y, x, h, w and the sigmoid of the
//            dequantized score.
//
// At most params.num_boxes entries are written: if the layer/stride
// configuration describes more boxes than that, decoding stops early instead
// of writing past the end of the array.
static void decode_boxes(uint8_t** model_out, BoxCenterEncode* boxes) {
  const int num_coordinates = 16;
  int box_idx = 0;
  for (int layer = 0; layer < params.num_layers; layer++) {
    // Grid size of this layer: ceil(input size / stride).
    int height_size = (params.input_height + params.anchor_stride[layer] - 1) /
                      params.anchor_stride[layer];
    int width_size = (params.input_width + params.anchor_stride[layer] - 1) /
                     params.anchor_stride[layer];
    // Boxes at even indices; scores at odd indices
    const uint8_t* boxes_out = model_out[2 * layer];
    const uint8_t* scores_out = model_out[2 * layer + 1];
    for (int i = 0; i < height_size * width_size; i++) {
      for (int j = 0; j < params.anchors_per_cell; j++) {
        // The caller sized the array from params.num_boxes; never exceed it
        // even if the other parameters are inconsistent with that count.
        if (box_idx >= params.num_boxes) {
          return;
        }
        int score_idx = i * params.anchors_per_cell + j;
        // Start of this anchor's 16-value group; only its first 4 are used.
        int chan_idx = num_coordinates * score_idx;
        // dequantize box
        boxes[box_idx].y =
            dequantize(boxes_out[chan_idx], params.box_zero_points[layer],
                       params.box_scales[layer]);
        boxes[box_idx].x =
            dequantize(boxes_out[chan_idx + 1], params.box_zero_points[layer],
                       params.box_scales[layer]);
        boxes[box_idx].h =
            dequantize(boxes_out[chan_idx + 2], params.box_zero_points[layer],
                       params.box_scales[layer]);
        boxes[box_idx].w =
            dequantize(boxes_out[chan_idx + 3], params.box_zero_points[layer],
                       params.box_scales[layer]);
        // dequantize score
        float dequant_score =
            dequantize(scores_out[score_idx], params.score_zero_points[layer],
                       params.score_scales[layer]);
        boxes[box_idx].score = sigmoid(dequant_score);
        box_idx++;
      }
    }
  }
}

// Turn a decoded, center-encoded box into corner encoding using its anchor.
//
// The center is the decoded offset divided by the matching global scale,
// scaled by the anchor size and shifted by the anchor position. The half
// extents are 0.5 * exp(decoded size / global scale) * anchor size. The score
// is copied through unchanged.
static void convert_box(const BoxCenterEncode* box_in, BoxCenterEncode* anchor,
                        BoxCornerEncode* box_out) {
  float half_width = 0.5 * expf(box_in->w / params.global_scales[3]) * anchor->w;
  float half_height =
      0.5 * expf(box_in->h / params.global_scales[2]) * anchor->h;
  float center_x = box_in->x / params.global_scales[1] * anchor->w + anchor->x;
  float center_y = box_in->y / params.global_scales[0] * anchor->h + anchor->y;

  box_out->score = box_in->score;
  box_out->xmin = center_x - half_width;
  box_out->xmax = center_x + half_width;
  box_out->ymin = center_y - half_height;
  box_out->ymax = center_y + half_height;
}

// Detect boxes by score thresholding
//
// Keeps every decoded box whose score is strictly greater than
// params.score_threshold, converts it to corner encoding with the anchor at
// the same index, and stores the results contiguously in boxes_out->box.
// boxes_out->num_boxes receives the number of boxes written.
//
// If boxes_out->box is NULL on entry, a buffer sized for exactly the detected
// boxes is allocated with malloc(); this function never frees it. If that
// allocation yields NULL (out of memory, or a zero-byte request when nothing
// was detected), boxes_out->num_boxes is set to 0 and nothing is written.
// NOTE(review): a non-NULL boxes_out->box is reused as-is and its capacity is
// not checked here -- presumably the caller must provide room for up to
// params.num_boxes entries; confirm against callers.
static void detect_boxes(const BoxCenterEncode* boxes_in,
                         BoxCenterEncode* anchors, Boxes* boxes_out) {
  // First pass: count the detections so the output can be sized exactly.
  int num_detected_boxes = 0;
  for (int i = 0; i < params.num_boxes; ++i) {
    if (boxes_in[i].score > params.score_threshold) {
      num_detected_boxes++;
    }
  }
  if (!(boxes_out->box)) {
    boxes_out->box = (BoxCornerEncode*)malloc(sizeof(BoxCornerEncode) *
                                              (size_t)num_detected_boxes);
    if (!(boxes_out->box)) {
      // No buffer to write into: report zero boxes instead of dereferencing
      // a NULL pointer below.
      boxes_out->num_boxes = 0;
      return;
    }
  }

  // Second pass: convert and store the detections in input order.
  num_detected_boxes = 0;
  for (int i = 0; i < params.num_boxes; ++i) {
    if (boxes_in[i].score > params.score_threshold) {
      convert_box(&(boxes_in[i]), &(anchors[i]),
                  &(boxes_out->box[num_detected_boxes]));
      num_detected_boxes++;
    }
  }
  boxes_out->num_boxes = num_detected_boxes;
}

// Decode and extract detected boxes
//
// model_out: per-layer model output buffers, passed through to
//            decode_boxes().
// boxes_out: receives the detected boxes in corner encoding and their count
//            (filled in by detect_boxes()).
//
// Two temporary arrays of params.num_boxes entries (decoded boxes and
// anchors) are allocated for the duration of the call and freed before
// returning. If either allocation fails, boxes_out->num_boxes is set to 0,
// boxes_out->box is left untouched, and no decoding is attempted.
void get_detected_boxes(uint8_t** model_out, Boxes* boxes_out) {
  // calloc() checks the count * size multiplication for overflow and
  // zero-fills the arrays, so any entry the helpers do not write holds
  // all-zero bytes rather than indeterminate values.
  BoxCenterEncode* boxes_in = (BoxCenterEncode*)calloc(
      (size_t)params.num_boxes, sizeof(BoxCenterEncode));
  BoxCenterEncode* anchors = (BoxCenterEncode*)calloc(
      (size_t)params.num_boxes, sizeof(BoxCenterEncode));
  if (!boxes_in || !anchors) {
    // free(NULL) is a no-op, so both can be released unconditionally.
    free(anchors);
    free(boxes_in);
    boxes_out->num_boxes = 0;
    return;
  }

  generate_anchors(anchors);

  decode_boxes(model_out, boxes_in);

  detect_boxes(boxes_in, anchors, boxes_out);

  free(anchors);
  free(boxes_in);
}
