/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef RISP4ML_COMMON_UTILS_H_
#define RISP4ML_COMMON_UTILS_H_

#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

#include "risp4ml/common/constants.h"

// Returns the RAW color channel index at position (x, y) for the given Bayer
// pattern.
// The Bayer pattern code defines the layout of each 2x2 quad, starting from
// the color at its top-left corner:
// 0: +---+---+ 1: +---+---+ 2: +---+---+ 3: +---+---+
//    | R | Gr|    | Gr| R |    | Gb| B |    | B | Gb|
//    +---+---+    +---+---+    +---+---+    +---+---+
//    | Gb| B |    | B | Gb|    | R | Gr|    | Gr| R |
//    +---+---+    +---+---+    +---+---+    +---+---+
// Pattern 0 is the base pattern; the other patterns are shifted versions of
// it: 1 is shifted by one column, 2 by one row, and 3 by one column and one
// row.
// NOTE(review): presumably x is the column and y the row, with (0, 0) the
// top-left pixel of the image; confirm against the implementation.
BayerIndex GetBayerIndex(BayerPattern bayerType, uint16_t x, uint16_t y);

// Returns the in-bounds index that corresponds to `x` in a Bayer image when
// `x` falls outside [0, size) and is mirrored back across the boundary.
// `size` is the extent of the axis that `x` indexes.
// NOTE(review): the exact mirroring rule (e.g. whether the 2-pixel Bayer phase
// is preserved, or how far out of bounds `x` may be) lives in the
// implementation, which is not visible here; confirm before relying on it.
int BayerMirrorBoundary(int x, int size);

// Limits `value` to the closed range [low, high].
// The lower bound is checked first, so if low > high the result is `low` for
// any value below `low` and `high` otherwise.
static inline uint32_t Clamp(uint32_t value, uint32_t low, uint32_t high) {
  if (value < low) {
    return low;
  }
  if (value > high) {
    return high;
  }
  return value;
}

// Computes lhs - rhs, saturating at 0 instead of wrapping around when
// rhs >= lhs.
static inline uint16_t SubUnsignedZeroClamp(uint16_t lhs, uint16_t rhs) {
  if (lhs <= rhs) {
    return 0;
  }
  return (uint16_t)(lhs - rhs);  // NOLINT(readability/casting)
}

// Counts the consecutive zero bits at the top of `in`, treating it as a
// BPP-bit number and scanning from bit BPP-1 down to bit BPP-N.
// Returns a value in [0, min(N, BPP)]; a non-positive N yields 0.
// Precondition: BPP must not exceed the bit width of unsigned int.
static inline int ClzMsb(int in, int BPP, int N) {
  // Never scan below bit 0: with N > BPP the original shift amount
  // BPP - lz - 1 would become negative (undefined behavior).
  const int limit = N < BPP ? N : BPP;
  // Test bits on the unsigned representation so that BPP == 32 does not
  // require shifting 1 into the sign bit of an int.
  const unsigned bits = (unsigned)in;  // NOLINT(readability/casting)
  int lz = 0;
  while (lz < limit && ((bits >> (BPP - lz - 1)) & 1u) == 0) {
    ++lz;
  }
  return lz;
}

// Rounds `x` to the nearest integral value, with halfway cases rounded away
// from zero (the same tie rule as C99 roundf()).
// NaN, infinities and values whose magnitude is at least 2^23 are returned
// unchanged: every such finite float is already an integer, and converting it
// to an integer type could overflow, which is undefined behavior.
static inline float Roundf(float x) {
  // 2^23 is the smallest magnitude at which a float has no fractional bits.
  // The check is written as a negated range test so that NaN, for which every
  // comparison is false, also takes this early return.
  if (!(x > -8388608.0f && x < 8388608.0f)) {
    return x;
  }
  // Here |x| < 2^23, so x +/- 0.5 (evaluated exactly in double) fits in a
  // long and the truncating conversion is well defined.
  long d = x < 0 ? x - 0.5 : x + 0.5;
  return (float)d;  // NOLINT(readability/casting)
}

// Converts floating point value `x` to a fixed point integer with
// `integer_bit` integer bits and `frac_bit` fractional bits.
// When `is_signed` is true the integer_bit + frac_bit bits form a two's
// complement range [-2^(total-1), 2^(total-1) - 1]; otherwise the range is
// [0, 2^total - 1], where total = integer_bit + frac_bit.
//
// x * 2^frac_bit is rounded to the nearest integer, with halfway cases
// rounded away from zero, and saturated to the representable range.
// +/-infinity saturate to the corresponding limit, and NaN returns 0.
//
// Preconditions: frac_bit >= 0 and the range must fit in an int, i.e.
// total <= 31 when unsigned and total <= 32 when signed.
// TODO(alexkaplan): Detect overflow/underflow.
static inline int FloatToFixedPoint(float x, int integer_bit, int frac_bit,
                                    bool is_signed) {
  const int total_bits = integer_bit + frac_bit;

  // Build the limits with 64-bit shifts and keep them in double, so they stay
  // exact for 25..32-bit formats. A float has only 24 mantissa bits, and
  // `1 << 31` on int is undefined.
  double min_value;
  double max_value;
  if (is_signed) {
    min_value = -(double)(1LL << (total_bits - 1));
    max_value = (double)((1LL << (total_bits - 1)) - 1);
  } else {
    min_value = 0.0;
    max_value = (double)((1LL << total_bits) - 1);
  }

  // Scaling a finite float by a power of two is exact in double.
  const double scaled = (double)x * (double)(1LL << frac_bit);
  if (isnan(scaled)) {
    return 0;
  }

  // Saturate BEFORE converting to an integer type, because converting an
  // out-of-range value is undefined behavior. Rounding is monotonic and the
  // limits are integers, so clamping first gives the same result as rounding
  // first.
  if (scaled < min_value) {
    return (int)min_value;  // NOLINT(readability/casting)
  }
  if (scaled > max_value) {
    return (int)max_value;  // NOLINT(readability/casting)
  }
  // Round half away from zero. The integral part of scaled +/- 0.5 lies in
  // [min_value, max_value], so the truncating conversion is well defined.
  return (int)(scaled < 0 ? scaled - 0.5 : scaled + 0.5);  // NOLINT
}

// Fixed point rounding: divides `value` by 2^right_shift and rounds to the
// nearest integer, with ties rounded up. It does this by adding bit
// (right_shift - 1), which is the bit just below the new binary point, as a
// carry. A shift of 0 returns `value` unchanged.
// Negative `value` relies on arithmetic right shift, which is
// implementation-defined in C.
static inline int Round(int value, int right_shift) {
  if (right_shift == 0) {
    return value;
  }
  const int half_bit = (value >> (right_shift - 1)) & 1;
  return (value >> right_shift) + half_bit;
}

// Helper function for linearly interpolating 2 values. When weight equals 0,
// output = val0. When weight equals 1 << weight_precision (1.0 in fixed
// point), output = val1.
// The weighted difference is rounded like Round(): the bit just below the
// binary point is added as a carry, so ties round up. For negative
// differences this assumes arithmetic right shift, which is
// implementation-defined in C.
//
// (val1 - val0) * weight is computed in 64 bits. In int arithmetic it
// overflows (undefined behavior) for ordinary inputs, e.g. 16-bit pixel
// values with 16-bit weight precision.
static inline int Lerp(int val0, int val1, int weight, int weight_precision) {
  const long long weighted_diff = ((long long)val1 - val0) * weight;
  const long long carry =
      weight_precision == 0 ? 0 : (weighted_diff >> (weight_precision - 1)) & 1;
  // For weight in [0, 1 << weight_precision] the result lies between val0 and
  // val1, so it fits in an int.
  return (int)(val0 + (weighted_diff >> weight_precision) + carry);  // NOLINT
}

#endif  // RISP4ML_COMMON_UTILS_H_
