blob: 80ce8f181673c7e78ba269b59ecbb6db2732bb41 [file] [log] [blame]
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// NMS (Non-Maximum Suppression) algorithm
#include "ssd_postprocess/nms.h"
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Fallback MIN/MAX helpers, only defined when no earlier header supplies them.
// These are plain macros: each argument may be evaluated more than once, so
// never pass expressions with side effects.
#if !defined(MIN)
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#endif  // !defined(MIN)

#if !defined(MAX)
#define MAX(a, b) (((a) > (b)) ? (a) : (b))
#endif  // !defined(MAX)
// Compute the area of one box given in corner encoding
// (xmin/ymin/xmax/ymax).
//
// Returns 0 for a degenerate or inverted box (non-positive width or height).
// Width and height are checked individually before multiplying. Clamping
// only the product would be wrong: if both extents are negative (e.g. the
// "intersection" of two boxes separated along both axes), the product is
// positive and a non-existent overlap would be reported as real area.
static float compute_box_area(const BoxCornerEncode* box) {
  const float width = box->xmax - box->xmin;
  const float height = box->ymax - box->ymin;
  if (width <= 0.0f || height <= 0.0f) return 0.0f;
  return width * height;
}
// Compute the IoU (intersection over union) of two corner-encoded boxes.
//
// Returns 0 when either box has non-positive area or when the boxes do not
// overlap. The overlap extents are checked directly rather than building an
// inverted intersection box and measuring its area. That way, two boxes
// separated along both axes (negative width AND negative height of the
// "intersection") can never yield a positive overlap.
static float compute_two_boxes_iou(const BoxCornerEncode* box1,
                                   const BoxCornerEncode* box2) {
  const float area1 = compute_box_area(box1);
  const float area2 = compute_box_area(box2);
  if (area1 <= 0.0f || area2 <= 0.0f) return 0.0f;
  const float inter_width =
      MIN(box1->xmax, box2->xmax) - MAX(box1->xmin, box2->xmin);
  const float inter_height =
      MIN(box1->ymax, box2->ymax) - MAX(box1->ymin, box2->ymin);
  if (inter_width <= 0.0f || inter_height <= 0.0f) return 0.0f;
  const float intersection_area = inter_width * inter_height;
  // A positive overlap implies both boxes have positive extents and the
  // intersection is no larger than either box, so the union is positive.
  return intersection_area / (area1 + area2 - intersection_area);
}
// comparator for qsort
static int comparator(const void* p, const void* q) {
float x = ((BoxCornerEncode*)p)->score;
float y = ((BoxCornerEncode*)q)->score;
return (y > x) - (y < x);
}
// Perform greedy non-maximum suppression to remove "similar" bounding boxes.
//
// boxes_in: candidate boxes. boxes_in->box is sorted IN PLACE by descending
//   score (qsort, not stable), so the caller observes it reordered.
// boxes_out: receives the kept boxes in descending score order, and
//   boxes_out->num_boxes is set to how many were written (0 if none).
//   boxes_out->box must have room for MIN(max_boxes, boxes_in->num_boxes)
//   entries.
// max_boxes: upper bound on the number of boxes written to boxes_out.
// iou_threshold: a box is dropped when its IoU with a higher-scoring kept box
//   is strictly greater than this value.
//
// A box can only be suppressed by a higher-scoring box that was itself kept.
// So each candidate is tested directly against the boxes already copied into
// boxes_out. This needs no scratch allocation, and the scan stops as soon as
// max_boxes boxes have been kept.
void nms(Boxes* boxes_in, Boxes* boxes_out, const int max_boxes,
         const float iou_threshold) {
  const int num_boxes = boxes_in->num_boxes;
  int num_kept = 0;
  // qsort requires a valid array pointer even for zero elements, and a
  // negative count must not be converted to size_t, so only proceed when
  // there is something to process.
  if (num_boxes > 0) {
    // Highest score first: each box is then only compared against boxes
    // that outrank it.
    qsort(boxes_in->box, (size_t)num_boxes, sizeof boxes_in->box[0],
          comparator);
    for (int i = 0; i < num_boxes && num_kept < max_boxes; i++) {
      const BoxCornerEncode* candidate = &boxes_in->box[i];
      int suppressed = 0;
      for (int k = 0; k < num_kept && !suppressed; k++) {
        suppressed = compute_two_boxes_iou(&boxes_out->box[k], candidate) >
                     iou_threshold;
      }
      if (!suppressed) {
        // num_kept <= i here, so this write never clobbers an input box that
        // has not been read yet, even if boxes_out->box and boxes_in->box
        // are the same array.
        boxes_out->box[num_kept++] = *candidate;
      }
    }
  }
  boxes_out->num_boxes = num_kept;
}