blob: a5b7e6e51e1c0fb431751878bae630c55477e8a5 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -----------------------------------------------------------------------------
//
// Misc. common math functions
//
// Authors: vrabaud (vincent.rabaud@google.com)
//
#ifndef WP2_DSP_MATH_H_
#define WP2_DSP_MATH_H_
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <type_traits>
#include "src/dsp/dsp.h"
const uint32_t kLogLookupIdxMax = 256;
extern const float kWP2Log2Table[kLogLookupIdxMax];
extern const float kWP2SLog2m1Table[kLogLookupIdxMax];
typedef float (*WP2Log2SlowFunc)(uint32_t v);
extern WP2Log2SlowFunc WP2Log2Slow;
extern WP2Log2SlowFunc WP2SLog2m1Slow;
// Fast calculation of log2(v) for integer input.
static inline float WP2Log2(uint32_t v) {
return (v < kLogLookupIdxMax) ? kWP2Log2Table[v] : WP2Log2Slow(v);
}
// Fast calculation of v * (1 - log2(v)) for integer input.
static inline float WP2SLog2m1(uint32_t v) {
return (v < kLogLookupIdxMax) ? kWP2SLog2m1Table[v] : WP2SLog2m1Slow(v);
}
// fast look-up version for small values
static inline float WP2Log2Fast(uint32_t v) {
assert(v < kLogLookupIdxMax);
return kWP2Log2Table[v];
}
// Returns (int)floor(log2(n)).
// use GNU builtins where available.
#if defined(__GNUC__)
static inline int WP2Log2Floor(uint32_t n) {
return (n == 0) ? 0 : (31 ^ __builtin_clz(n));
}
// counts the number of trailing zero
static inline int WP2Ctz(uint32_t n) { return __builtin_ctz(n); }
#elif defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))
#include <intrin.h>
#pragma intrinsic(_BitScanReverse)
#pragma intrinsic(_BitScanForward)
static inline int WP2Log2Floor(uint32_t n) {
if (n == 0) return 0;
unsigned long first_set_bit; // NOLINT (runtime/int)
_BitScanReverse(&first_set_bit, n);
return first_set_bit;
}
static inline int WP2Ctz(uint32_t n) {
unsigned long first_set_bit; // NOLINT (runtime/int)
_BitScanForward(&first_set_bit, n);
return first_set_bit;
}
#else // default: use the C-version.
#define WP2_NEED_LOG_TABLE_8BIT
extern const uint8_t WP2LogTable8bit[256 + 1];
static inline int WP2Log2Floor(uint32_t n) {
int log_value = 0;
while (n > 256) {
log_value += 8;
n >>= 8;
}
return log_value + WP2LogTable8bit[n];
}
static inline int WP2Ctz(uint32_t n) {
for (uint32_t i = 0; i < 32; ++i, n >>= 1) {
if (n & 1) return i;
}
return 32;
}
#endif
// To be used by constants and assertions.
// Returns floor(log2(n)).
static constexpr uint32_t WP2Log2Floor_k(uint64_t n) {
return (n < 2) ? 0
: (n < 4) ? 1
: (n < 8) ? 2
: (n < 16) ? 3
: (n < 32) ? 4
: (n < 64) ? 5
: (n < 128) ? 6
: (n < 256) ? 7
: (8u + WP2Log2Floor_k(n >> 8));
}
// Returns ceil(log2(n)).
static constexpr uint32_t WP2Log2Ceil_k(uint64_t n) {
return (n <= 1) ? 0 : (1 + WP2Log2Floor_k(n - 1));
}
namespace WP2 {
// Computes the Shannon entropy of a histogram.
typedef float (*ShannonEntropyFunc)(const uint32_t* histo, uint32_t length);
extern ShannonEntropyFunc ShannonEntropy;
} // namespace WP2
// Only use for small values, smaller than 0xffff.
typedef int32_t (*WP2InnerProductFunc)(const int16_t* const a,
const int16_t* const b, size_t size);
extern WP2InnerProductFunc WP2InnerProduct;
// Must be called before calling any of the above methods.
void WP2MathInit();
//------------------------------------------------------------------------------
namespace WP2 {
// Returns the closest value to 'v' in [min:max].
template <typename T>
static constexpr T Clamp(T v, T min, T max) {
return (v < min) ? min : (v > max) ? max : v;
}
// Returns the closest value to 'v' fitting in 'num_bits' (sign included).
template <typename T>
static inline T ClampToSigned(T v, uint32_t num_bits) {
static_assert(std::numeric_limits<T>::min() < 0, "T must be signed");
assert(num_bits > 0 && num_bits < sizeof(T) * 8);
const T limit = (T)1 << (num_bits - 1);
return Clamp<T>(v, -limit, limit - 1);
}
// Returns the closest value to 'a + b' in [min:max], for under/overflow check.
template <typename T>
static constexpr T SafeAdd(T a, T b, T min = std::numeric_limits<T>::min(),
T max = std::numeric_limits<T>::max()) {
// First condition should be optimized out for unsigned types.
return (b < 0) ? ((/*a+b >= min*/ a >= min - b) ? a + b : min)
: ((/*a+b <= max*/ a <= max - b) ? a + b : max);
}
// Returns the closest value to 'a - b' in [min:max], for under/overflow check.
template <typename T>
static constexpr T SafeSub(T a, T b, T min = std::numeric_limits<T>::min(),
T max = std::numeric_limits<T>::max()) {
// First condition should be optimized out for unsigned types.
return (b < 0) ? ((/*a-b <= max*/ a <= max + b) ? a - b : max)
: ((/*a-b >= min*/ a >= min + b) ? a - b : min);
}
//------------------------------------------------------------------------------
// Regular shifts with ignored Undefined Behavior Sanitizer warnings about the
// shift of negative values.
template <typename T>
inline WP2_UBSAN_IGNORE_UNDEF T RightShift(T v, uint32_t shift) {
return v >> shift;
}
template <typename T>
inline T RightShiftRound(T v, uint32_t shift) {
return RightShift(v + ((T)1 << shift >> 1), shift);
}
template <typename T>
inline WP2_UBSAN_IGNORE_UNDEF T LeftShift(T v, uint32_t shift) {
return v << shift;
}
// Returns 'v' shifted right or left by abs(from - to). Right shifting uses an
// extra bit during computation for rounding.
template <typename T>
inline T ChangePrecision(T v, uint32_t from, uint32_t to) {
return (from <= to) ? LeftShift<T>(v, to - from)
: RightShiftRound<T>(v, from - to);
}
// Returns ceil(a / b).
template <typename T>
static constexpr T DivCeil(T a, T b) {
return (a == 0) ? 0 : (a - 1) / b + 1;
}
// Returns round(a / b).
template <typename T>
static constexpr T DivRound(T a, T b) {
return ((a < 0) == (b < 0)) ? ((a + b / 2) / b) : ((a - b / 2) / b);
}
static constexpr uint32_t Pad(uint32_t dim, uint32_t pad) {
return (dim + pad - 1) & ~(pad - 1);
}
// 3x3 matrix multiplication such as 'a * b = c'.
// Make sure there is no overflow by keeping a margin of 3 bits in 'a' or 'b').
template <typename TypeA, typename TypeB, typename TypeC>
static void Multiply(const TypeA a[9], const TypeB b[9], TypeC c[9]) {
for (uint32_t j : {0, 1, 2}) {
for (uint32_t i : {0, 1, 2}) {
c[j * 3 + i] = (TypeC)a[j * 3 + 0] * b[0 * 3 + i] +
(TypeC)a[j * 3 + 1] * b[1 * 3 + i] +
(TypeC)a[j * 3 + 2] * b[2 * 3 + i];
}
}
}
// Vector and 3x3 matrix multiplication such as 'm * in = out'.
template <typename TypeIn, typename TypeMatrix, typename TypeOut>
inline void Multiply(TypeIn in0, TypeIn in1, TypeIn in2, const TypeMatrix m[9],
TypeOut* const out0, TypeOut* const out1,
TypeOut* const out2) {
*out0 = m[0] * (TypeOut)in0 + m[1] * (TypeOut)in1 + m[2] * (TypeOut)in2;
*out1 = m[3] * (TypeOut)in0 + m[4] * (TypeOut)in1 + m[5] * (TypeOut)in2;
*out2 = m[6] * (TypeOut)in0 + m[7] * (TypeOut)in1 + m[8] * (TypeOut)in2;
}
// Fixed-point vector/matrix multiplication such as 'out = (m * in) >> shift'.
// Make sure the range of TInternal can hold internal computation.
template <typename TypeIn, typename TypeMatrix, typename TypeOut,
typename TInternal = int32_t>
static void Multiply(TypeIn in0, TypeIn in1, TypeIn in2, const TypeMatrix m[9],
uint32_t shift, TypeOut* const out0, TypeOut* const out1,
TypeOut* const out2) {
TInternal tmp0, tmp1, tmp2;
Multiply(in0, in1, in2, m, &tmp0, &tmp1, &tmp2);
tmp0 = RightShiftRound(tmp0, shift);
tmp1 = RightShiftRound(tmp1, shift);
tmp2 = RightShiftRound(tmp2, shift);
constexpr TInternal kMin = std::numeric_limits<TypeOut>::min();
constexpr TInternal kMax = std::numeric_limits<TypeOut>::max();
*out0 = (TypeOut)Clamp(tmp0, kMin, kMax);
*out1 = (TypeOut)Clamp(tmp1, kMin, kMax);
*out2 = (TypeOut)Clamp(tmp2, kMin, kMax);
}
// Useful macros for dividing by alpha (un-mult)
// for all values v in [0,255] and a in [1,255], we have:
// (v * 255 + a/2) / a == (v * M + 1<<15) >> 16
// where M = ((255u << 16) + a - 1) / a
// ... except for the pair v=207, a=193 !!
// But we're taking the std::min(..., 255u), so it's fine to use 16 bit.
extern const uint32_t kAlphaDiv[256]; // <- ((255u << 16) + a - 1) / a
static inline uint32_t DivByAlphaDiv(uint32_t x, uint32_t mult) {
return (x * mult + (1u << 15)) >> 16;
}
// Same as (v+127)/255 for v up to 255 and error margin of +-1 up to 1024*256.
static inline int32_t DivBy255(int32_t v) { return (v + (v >> 8) + 128) >> 8; }
static inline uint32_t DivBy255(uint32_t v) {
return (v + (v >> 8) + 128) >> 8;
}
// Returns the square root of integer 'v' such as 'r^2 <= v' and '(r+1)^2 > v'.
// Make sure 'v^2' fits in T. Slow func, use sparingly or optimize it.
template <typename T>
T SqrtFloor(T v) {
static_assert(std::is_integral<T>::value, "SqrtFloor() is only for integers");
T lo = std::max((T)0, std::numeric_limits<T>::min()), hi = v;
while (lo < hi) {
const auto mid = (lo + hi + 1) / 2;
if (mid * mid > v) {
hi = mid - 1;
} else {
lo = mid;
}
}
return lo;
}
//------------------------------------------------------------------------------
// Random generator
class PseudoRNG {
public:
PseudoRNG();
// Returns a pseudo-random number in (-amp, amp) range. 'amp' must be > 0.
// (uses D.Knuth's Difference-based random generator).
int32_t GetSigned(int32_t amp) {
assert(amp > 0);
int32_t diff;
do {
diff = Update();
diff = ((int64_t)diff * amp) >> 31;
} while (diff == -amp);
return diff;
}
// returns a pseudo-random number in [0, amp), with amp > 0 (slow!)
uint32_t GetUnsigned(uint32_t amp) {
assert(amp > 0);
uint32_t diff;
do {
diff = (uint32_t)Update();
diff = (((uint64_t)diff * amp) >> 32);
} while (diff >= amp);
return diff;
}
// returns a value in [min, max]
int32_t Get(int32_t min, int32_t max) {
assert(max >= min);
return GetUnsigned((uint32_t)(max + 1 - min)) + min;
}
private:
int32_t Update(); // update state and return uniform 31b-signed number
static constexpr uint32_t kTabSize = 55u;
static const uint32_t kRandomTable[kTabSize];
uint32_t index1_, index2_;
uint32_t tab_[kTabSize];
};
} // namespace WP2
#endif /* WP2_DSP_MATH_H_ */