blob: 9c6bf496ea4468cfbd9d90db52e8023f3115c40d [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -----------------------------------------------------------------------------
//
// Asymmetric Numeral System coder
//
// Author: Skal (pascal.massimino@gmail.com)
#ifndef WP2_UTILS_ANS_H_
#define WP2_UTILS_ANS_H_
#include <array>
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <string>
#include "src/dsp/dsp.h"
#include "src/wp2/base.h"
#include "src/wp2/debug.h"
#ifdef HAVE_CONFIG_H
#include "src/wp2/config.h"
#endif
#include "src/dsp/math.h"
#include "src/utils/utils.h"
#include "src/utils/vector.h"
// Trailing parameter declaration shared by the ANSDec::Read*() methods: a
// human-readable 'label' naming the value being read. Inside those methods it
// is forwarded to the tracing / bit-accounting hooks (Trace(), BitTrace()).
#define WP2_OPT_LABEL const char label[]
namespace WP2 {
class DataSource;  // Forward declaration: only used through pointers here.

#if defined(WP2_BITTRACE)
// Per-label statistics (accumulated amount of bits and number of
// occurrences), keyed by the label string.
// Only declared when WP2_BITTRACE is defined.
using BitCounts = std::map<const std::string, LabelStats>;
#endif  // defined(WP2_BITTRACE)
// When the following #define is uncommented, the 'label' string passed along
// with every value written by the ANS encoder is also inserted in the stream,
// as a hash, or as the whole string if WP2_ENC_DEC_DEEP_MATCH is defined. When
// decoding the ANS stream, it is asserted that the same string is read back,
// to make sure that the same information is read in the same order as it was
// written.
// #define WP2_ENC_DEC_MATCH
#define ANS_LOG_TAB_SIZE 14
#define ANS_TAB_SIZE (1u << ANS_LOG_TAB_SIZE)
// maximum number of symbols for dictionaries and tables
#define ANS_MAX_SYMBOLS 1024 // 10 bits
// bits have a probability within [0, PROBA_MAX]
#define PROBA_BITS 16
#define PROBA_MAX (1u << PROBA_BITS)
// number of bits transferred at once during I/O between state_ and the
// bitstream
#define IO_BITS 16
#define IO_LIMIT_BITS 16
// Note: IO_LIMIT is only #defined further below. This is fine because a macro
// body is expanded where the macro is used, not where it is defined.
#define ANS_SIGNATURE (0xf3 * IO_LIMIT) // Initial state, used as CRC.
static_assert(IO_BITS == 16, "IO_BITS should be == 16");
// ANSDec::GetBitCount() may differ from the bits actually used by at most
// this many bits (see ANSDec::kBitCountAccuracy).
static constexpr uint32_t kANSPaddingCost = (IO_LIMIT_BITS + IO_BITS);
// Max number (inclusive) of bits for the range used during range-value coding.
static constexpr uint32_t kANSMaxRangeBits = 14;
// Maximum range that can be used. As a range is always > 0, it is stored
// minus 1. Ranges can hence go up to (1u << kANSMaxRangeBits) included while
// the stored value still fits on kANSMaxRangeBits bits.
static constexpr uint32_t kANSMaxRange = (1u << kANSMaxRangeBits);
// Max number (inclusive) of bits to use to describe a uniform value.
static constexpr uint32_t kANSMaxUniformBits = IO_BITS;
// Adaptive symbol with a small dictionary.
// APROBA_BITS must have slightly less precision than PROBA_BITS so that the
// range fits in ANSTokenInfoASym as 14+14 bits.
// NOTE(review): unlike PROBA_MAX, APROBA_MAX is a signed int expression.
#define APROBA_BITS 14
#define APROBA_MAX (1 << APROBA_BITS)
// Also the default value of ANSAdaptiveSymbol::adapt_factor_.
#define APROBA_ADAPT_SPEED 4096 // must be strictly less than 32768
// Default 'speed' argument of ANSAdaptiveSymbol::SetAdaptationSpeed().
// NOTE(review): the name suggests a sentinel meaning "no explicit speed was
// given"; confirm against the implementation.
static constexpr uint32_t kANSAProbaInvalidSpeed = (1 << PROBA_BITS);
// some derived constants:
#define IO_BYTES (IO_BITS / 8)
#define IO_LIMIT (1ull << IO_LIMIT_BITS)
// TODO(vrabaud): make sure the ANS read/write functions are optimized/inlined.
//------------------------------------------------------------------------------
// Decoding
// Symbol information, as stored in a spread table (ANSCodes) or built by
// ANSAdaptiveSymbol::GetInfo().
// NOTE(review): field order and types define the in-memory layout and the
// order of aggregate initialization; keep them unchanged.
struct ANSSymbolInfo { // Symbol information.
// Offset to the first ANSSymbolInfo of that symbol. For
// ANSAdaptiveSymbol::GetInfo() this is cumul_[sym], the start of the symbol's
// interval in the cumulative distribution.
uint16_t offset;
// Frequency of the symbol. NOTE(review): this was documented as being in
// [0, PROBA_MAX], but PROBA_MAX is (1u << 16) and does not fit in a uint16_t.
// The table norms used in this header (ANS_TAB_SIZE, APROBA_MAX) are 1 << 14;
// confirm the actual upper bound.
uint16_t freq;
// The symbol value (GetInfo() stores 'sym' here).
uint16_t symbol;
// Declared here, defined elsewhere. NOTE(review): judging by its name and
// parameters, presumably returns a copy of this info rescaled to 'tab_size';
// confirm against the implementation.
ANSSymbolInfo Rescale(const ANSSymbolInfo& max_symbol,
uint32_t tab_size = ANS_TAB_SIZE) const;
};
// Spread table of ANSSymbolInfo, presumably indexed by table slot.
// ANSCountsToSpreadTable() resizes it to 'table_size' if needed.
typedef VectorNoCtor<ANSSymbolInfo> ANSCodes;
// Adaptive model of a binary event: counts the zeros among all the observed
// bits, and derives the cost of each bit value from these counts.
// TODO(maryla): merge this into ANSAdaptiveSymbol to make them simpler to use?
class ANSBinSymbol {
 public:
  // Defined out of line. NOTE(review): 'p0' / 'p1' presumably seed the counts
  // of zeros / ones; confirm against the implementation.
  explicit ANSBinSymbol(uint32_t p0 = 1, uint32_t p1 = 1);
  ANSBinSymbol(ANSBinSymbol&&) = default;
  ANSBinSymbol(const ANSBinSymbol& other) = default;
  ANSBinSymbol& operator=(const ANSBinSymbol&) = default;

  // Records one more observation of 'bit', unless kMaxSum observations were
  // already recorded, and returns 'bit' unchanged.
  inline uint32_t Update(uint32_t bit) {
    if (num_total_ >= kMaxSum) return bit;  // Saturated: stop adapting.
    UpdateCount(bit);
    return bit;
  }

  // Number of zeros observed so far.
  uint32_t NumZeros() const { return num_zeros_; }
  // Total number of bits (zeros and ones) observed so far.
  uint32_t NumTotal() const { return num_total_; }

  // Cost of storing 'bit': log2(total count) - log2(count of 'bit').
  // Returns 0 when all the observed bits (if any) have the same value.
  float GetCost(uint32_t bit) const {
    const bool single_value_seen =
        (num_zeros_ == 0) || (num_zeros_ == num_total_);
    if (single_value_seen) return 0;
    const auto num_same = (bit != 0) ? (num_total_ - num_zeros_) : num_zeros_;
    return WP2Log2(num_total_) - WP2Log2(num_same);
  }

 private:
  // The coded sources are mostly stationary, so updating the probability past
  // kMaxSum events is usually irrelevant.
  static constexpr uint32_t kMaxSum = 256u;
  uint16_t num_zeros_;  // Number of observed zeros.
  uint16_t num_total_;  // Number of observed bits.
  // Defined out of line; presumably increments the counters for 'bit'.
  void UpdateCount(uint32_t bit);
};
// Decoder only:
// Builds a 'flat' spread table from a set of frequencies. The counts[]
// distribution is re-normalized so that its sum equals 'table_size', and
// 'codes' is resized to 'table_size' if needed.
// Returns ok if there are no ill-defined counts.
WP2Status ANSCountsToSpreadTable(uint32_t* counts, uint32_t max_symbol,
                                 uint32_t table_size, ANSCodes& codes);

// Analyzes counts[0..size) and re-normalizes it so that its total is exactly
// 'sum'.
WP2Status ANSNormalizeCounts(uint32_t* counts, uint32_t size, uint32_t sum);
// Stores and adapts the cumulative distribution of a small dictionary of at
// most APROBA_MAX_SYMBOL symbols, normalized to APROBA_MAX.
// NOTE(review): APROBA_MAX_SYMBOL and ANSGetSymbol() are not declared in this
// header; presumably they come from one of the included headers.
class ANSAdaptiveSymbol {
 public:
  enum class Method {
    // Adaptation is constant with a speed in kAdaptationSpeeds.
    kConstant,
    // Adaptation is exactly the AOM one, with a rate depending on how many
    // symbols got already written.
    kAOM,
    kNum
  };

  // Until initialized, 'method_' is the kNum sentinel and 'max_symbol_' is 0,
  // so that the asserts in GetInfo()/GetProba()/GetCost() fire instead of
  // reading an uninitialized size.
  ANSAdaptiveSymbol() : method_(Method::kNum) {}

  // Initializes with a uniform distribution.
  void InitFromUniform(uint32_t max_symbol);
  // Initializes from an un-normalized pdf. Norm is APROBA_MAX.
  // Returns a non-ok status if the distribution can't be normalized.
  WP2Status InitFromCounts(const uint32_t counts[], uint32_t max_symbol);
  // Initializes from a cdf whose norm is 'max_proba'. Normalization to
  // APROBA_MAX is performed if needed.
  WP2Status InitFromCDF(const uint16_t* cdf, uint32_t max_symbol,
                        uint32_t max_proba = APROBA_MAX);
  // Copies the state of 'other' into this object.
  void CopyFrom(const ANSAdaptiveSymbol& other);

  // Adaptation speed goes from 0x3fffu (fast adaptation) to 0 (no adaptation).
  // Special case is speed 0xffffu (instant adaptation).
  // 'speed' defaults to kANSAProbaInvalidSpeed.
  void SetAdaptationSpeed(Method method,
                          uint32_t speed = kANSAProbaInvalidSpeed);

  // Returns the info of the symbol selected by 'proba' (as computed by
  // ANSGetSymbol()). 'proba' is in [0, APROBA_MAX) range.
  ANSSymbolInfo GetSymbol(uint32_t proba) const {
    return GetInfo(ANSGetSymbol(max_symbol_, cumul_.data(), proba));
  }

  // Convenience representation of the range of 'sym' as an ANSSymbolInfo:
  // offset = cumul_[sym], freq = cumul_[sym + 1] - cumul_[sym].
  // 'sym' must be in [0, max_symbol_) range.
  inline ANSSymbolInfo GetInfo(uint32_t sym) const {
    assert(sym < max_symbol_);
    ANSSymbolInfo info;
    info.symbol = static_cast<uint16_t>(sym);
    info.offset = cumul_[sym];
    info.freq = cumul_[sym + 1] - info.offset;
    return info;
  }

  // Updates the cumulative distribution after coding symbol 'sym',
  // preserving the norm equal to APROBA_MAX.
  void Update(uint32_t sym);

  // Returns proba = (cumul[sym + 1] - cumul[sym]) / APROBA_MAX
  float GetProba(uint32_t sym) const {
    assert(sym < max_symbol_);
    const uint32_t w = cumul_[sym + 1] - cumul_[sym];
    return (1.f / APROBA_MAX) * w;
  }
  // Returns the cost of 'sym', i.e. -log2(GetProba(sym)) from kCostTable.
  // 'sym' must be in [0, max_symbol_) range, like for GetInfo()/GetProba().
  float GetCost(uint32_t sym) const {
    assert(sym < max_symbol_);
    return kCostTable[cumul_[sym + 1] - cumul_[sym]];
  }
  // Returns the cost of 'sym' knowing that the coded symbol is within
  // [0, max_symbol]: the cost of 'sym' minus the cost of the probability mass
  // of [0, max_symbol] (cumul_[0] is always 0).
  // NOTE(review): cumul_[max_symbol + 1] is read, so 'max_symbol' presumably
  // has to be < max_symbol_; this is not asserted.
  float GetCost(uint32_t sym, uint32_t max_symbol) const {
    assert(sym <= max_symbol);
    return kCostTable[cumul_[sym + 1] - cumul_[sym]] -
           kCostTable[cumul_[max_symbol + 1] - cumul_[0]];
  }

  // debug
  void Print(bool use_percents = true) const;
  void PrintProbas(uint32_t norm = 0) const;

  // Computes the 'distance' between two cdf. Can be used to
  // monitor convergence to final stationary distribution.
  float ComputeDistance(const ANSAdaptiveSymbol& from) const;
  // Distance (in [0..1]) based on variance.
  float ComputeNonUniformDistance() const;

  // Will return the optimal speed index to use to reach the final distribution.
  // The index refers to kAdaptationSpeeds[].
  // If header_bits is not nullptr, and in case a fixed distribution is better
  // than an adaptive one (ie.: adaptation speed looks too fast), will return
  // the number of header-bits that can be spent to code the fixed distribution.
  // If fixed-proba is not advantageous, 0 is returned as header_bits.
  void FindBestAdaptationSpeed(const Vector_u8& syms, Method* method,
                               uint32_t* speed,
                               uint32_t* header_bits = nullptr) const;
  // Evaluates the cost of the sequence of symbols for the given
  // adaptation speed.
  float ScoreSequence(Method method, uint32_t speed,
                      const Vector_u8& syms) const;

  const uint16_t* GetCumul() const { return cumul_.data(); }
  uint32_t NumSymbols() const { return max_symbol_; }
  static inline float GetFreqCost(uint32_t freq) { return kCostTable[freq]; }

 protected:
  Method method_;
  uint32_t max_symbol_ = 0;  // Number of symbols, at most APROBA_MAX_SYMBOL.
  uint32_t adapt_factor_ = APROBA_ADAPT_SPEED;
  // cumulative frequency. cumul[0] is always 0, cumul[last] is always
  // APROBA_MAX.
  std::array<uint16_t, APROBA_MAX_SYMBOL + 1> cumul_;
  void InitCDF();  // fills cdf_base_[] and cdf_var_[] according to cumul_[]
  std::array<uint16_t, APROBA_MAX_SYMBOL> cdf_base_;
  std::array<uint16_t, APROBA_MAX_SYMBOL * 2 - 1> cdf_var_;
  // precalc for -log2(GetProba())
  // TODO(skal): Costs should be normalized to int16_t, not float.
  static const float kCostTable[APROBA_MAX + 1];
};
#if defined(WP2_ENC_DEC_MATCH)
#if defined(WP2_ENC_DEC_DEEP_MATCH)
// Returns the full string "<debug_prefix><label>, <extra0>, <extra1>, ..."
// checked between encoder and decoder. Each extra value is truncated to 16
// bits before being printed.
template <typename... Ts>
inline std::string ANSString(const std::string& debug_prefix,
                             const char label[], Ts... extra) {
  std::string str = debug_prefix;
  str += label;
  // Unlike a C array, an initializer_list may be empty, so this also compiles
  // when no extra value is given.
  for (const uint16_t e :
       std::initializer_list<uint16_t>{static_cast<uint16_t>(extra)...}) {
    str += ", " + std::to_string(e);
  }
  return str;
}
#else
// djb2 string hash, truncated to 16 bits, of the debug prefix, the label and
// the extra values (each truncated to 16 bits), in that order.
template <typename... Ts>
inline uint16_t ANSStringHash(const std::string& debug_prefix,
                              const char label[], Ts... extra) {
  uint16_t hash = 5381;
  for (char c : debug_prefix) hash = ((hash << 5) + hash) + c;
  for (int i = 0; label[i] != 0; ++i) hash = ((hash << 5) + hash) + label[i];
  // Unlike a C array, an initializer_list may be empty, so this also compiles
  // when no extra value is given.
  for (const uint16_t e :
       std::initializer_list<uint16_t>{static_cast<uint16_t>(extra)...}) {
    hash = ((hash << 5) + hash) + e;
  }
  return hash;
}
#endif  // defined(WP2_ENC_DEC_DEEP_MATCH)
#endif  // defined(WP2_ENC_DEC_MATCH)
// Class for holding the decoding state of an ANS stream read from a
// DataSource.
class ANSDec {
 public:
  explicit ANSDec(DataSource* const data_source) { Init(data_source); }
  ANSDec(const ANSDec&) = delete;
  // NOTE(review): only the copy constructor is deleted; the copy assignment
  // operator is still implicitly declared. Consider deleting it as well if no
  // caller relies on it.

  // Initializes the ANSDec object to read from 'data_source'. Also called by
  // the constructor.
  void Init(DataSource* data_source);

  // Decodes a symbol, according to the spread table 'codes'.
  uint32_t ReadSymbol(const ANSSymbolInfo codes[], uint32_t log2_tab_size,
                      WP2_OPT_LABEL) {
    const uint32_t symbol = ReadSymbolInternal(codes, log2_tab_size);
    Trace("%s: symbol=%u", label, symbol);
    BitTrace(symbol, LabelStats::Type::Symbol, label);
    return symbol;
  }
  // Decodes a symbol, according to the spread table 'codes'.
  // 'max_index' is the index in 'codes' for which we know for sure the read
  // value is not strictly superior.
  uint32_t ReadSymbol(const ANSSymbolInfo codes[], uint32_t log2_tab_size,
                      uint32_t max_index, WP2_OPT_LABEL) {
    const uint32_t symbol = ReadSymbolInternal(codes, log2_tab_size, max_index);
    Trace("%s: symbol=%u", label, symbol);
    BitTrace(symbol, LabelStats::Type::Symbol, label);
    return symbol;
  }

  // Decodes a symbol from small adaptive dictionary.
  // 'asym' probability is adapted with the decoded symbol.
  uint32_t ReadASymbol(ANSAdaptiveSymbol* const asym, WP2_OPT_LABEL) {
    const uint32_t symbol = ReadSymbol(*asym, label);
    asym->Update(symbol);
    return symbol;
  }
  // Decodes a symbol from small adaptive dictionary.
  // 'asym' probability is *NOT* adapted.
  uint32_t ReadSymbol(const ANSAdaptiveSymbol& asym, WP2_OPT_LABEL) {
    const uint32_t symbol = ReadSymbolInternal(asym);
    Trace("%s: symbol=%u", label, symbol);
    BitTrace(symbol, LabelStats::Type::ASymbol, label);
    return symbol;
  }
  // Same as above, with an extra 'max_index' bound (see the spread-table
  // overload for its meaning). 'asym' probability is *NOT* adapted.
  uint32_t ReadSymbol(const ANSAdaptiveSymbol& asym, uint32_t max_index,
                      WP2_OPT_LABEL) {
    const uint32_t symbol = ReadSymbolInternal(asym, max_index);
    Trace("%s: symbol=%u", label, symbol);
    BitTrace(symbol, LabelStats::Type::ASymbol, label);
    return symbol;
  }

  // Decodes a binary symbol with probability defined by the number of 0's and
  // total occurrences.
  uint32_t ReadBit(uint32_t num_zeros, uint32_t num_total, WP2_OPT_LABEL) {
    const uint32_t bit = ReadBitInternal(num_zeros, num_total);
    // All traced values are unsigned, hence %u.
    Trace("%s: bit=%u num_zeros=%u num_total=%u", label, bit, num_zeros,
          num_total);
    BitTrace(bit, LabelStats::Type::Bit, label);
    return bit;
  }
  // Decodes an adaptive binary symbol with statistics 'stats'.
  // 'stats' is updated upon return.
  uint32_t ReadABit(ANSBinSymbol* const stats, WP2_OPT_LABEL) {
    const uint32_t bit = ReadABitInternal(stats);
    Trace("%s: abit=%u", label, bit);
    BitTrace(bit, LabelStats::Type::ABit, label);
    return bit;
  }
  // Decodes a bool with a fifty-fifty probability.
  inline bool ReadBool(WP2_OPT_LABEL) { return (ReadUValue(1, label) != 0); }

  // Decodes a uniform value known to be in range [0..1 << bits), with 'bits'
  // in [0, IO_BITS] range.
  // If 'bits' == 0, it returns 0.
  uint32_t ReadUValue(uint32_t bits, WP2_OPT_LABEL) {
    const uint32_t value = ReadUValueInternal(bits);
    Trace("%s: value=0x%x bits=%u", label, value, bits);
    BitTrace(value, LabelStats::Type::UValue, label);
    return value;
  }
  // Same as ReadUValue() with signed values in [ -2^(bits-1) .. 2^(bits-1) ).
  inline int32_t ReadSUValue(uint32_t bits, WP2_OPT_LABEL) {
    return static_cast<int32_t>(ReadUValue(bits, label)) - ((1 << bits) >> 1);
  }

  // Decodes a uniform value known to be in [0..range), an interval fitting in
  // kANSMaxRangeBits bits.
  uint32_t ReadRValue(uint32_t range, WP2_OPT_LABEL) {
    const uint32_t value = ReadRValueInternal(range);
    Trace("%s: value=0x%x range=0x%x", label, value, range);
    BitTrace(value, LabelStats::Type::RValue, label);
    return value;
  }
  // Decodes a uniform value known to be in [min..max], an interval fitting in
  // kANSMaxRangeBits bits (i.e. max - min + 1 <= kANSMaxRange).
  inline int32_t ReadRange(int32_t min, int32_t max, WP2_OPT_LABEL) {
    assert(min <= max);
    // Computed in unsigned arithmetic: 'max - min + 1' in int32_t would
    // overflow (undefined behavior) for wide intervals.
    const uint32_t range =
        static_cast<uint32_t>(max) - static_cast<uint32_t>(min) + 1u;
    assert(range >= 1u && range <= kANSMaxRange);
    return static_cast<int32_t>(ReadRValue(range, label) +
                                static_cast<uint32_t>(min));
  }

  // Returns the current status: ok as long as no error occurred.
  WP2Status GetStatus() const { return status_; }

  // Appends 'prefix' followed by '/' to the prefix that appears when tracing
  // read values. No-op unless WP2_BITTRACE, WP2_TRACE or WP2_ENC_DEC_MATCH is
  // defined.
  void AddDebugPrefix(const char prefix[]) {
#if defined(WP2_BITTRACE) || defined(WP2_TRACE) || defined(WP2_ENC_DEC_MATCH)
    debug_prefix_ += prefix;
#if defined(WP2_BITTRACE)
    // Resets the 'bits' of this prefix entry and counts one more occurrence.
    counters_[kPrefixStr + debug_prefix_].bits = 0.f;
    ++counters_[kPrefixStr + debug_prefix_].num_occurrences;
#endif
    debug_prefix_ += "/";
#else
    (void)prefix;
#endif
  }

  // Pops the latest addition to the prefix, i.e. removes the last
  // "<prefix>/" component appended by AddDebugPrefix().
  void PopDebugPrefix() {
#if defined(WP2_BITTRACE) || defined(WP2_TRACE) || defined(WP2_ENC_DEC_MATCH)
    assert(!debug_prefix_.empty());
    // Start right before the trailing '/' and look for the previous one.
    int i = static_cast<int>(debug_prefix_.size()) - 2;
    while (i >= 0 && debug_prefix_[i] != '/') --i;
    debug_prefix_.erase(i + 1);
#endif
  }

 private:
  // Logs the read value through WP2Trace(). When WP2_ENC_DEC_MATCH is defined,
  // also reads back the label written by the encoder (a 16-bit hash, or the
  // whole string with WP2_ENC_DEC_DEEP_MATCH) and asserts that it matches.
  template <typename... Ts>
  inline void Trace(const char* format, const char label[], Ts... extra) {
    (void)sizeof...(extra);
    (void)label;
    WP2Trace(format, debug_prefix_, label, extra...);
#if defined(WP2_ENC_DEC_MATCH)
#if defined(WP2_ENC_DEC_DEEP_MATCH)
    const std::string str = ANSString(debug_prefix_, label, extra...);
    std::string str_read(ReadDebugUValue(8), '.');
    for (char& c : str_read) c = (char)ReadDebugUValue(8);
    if (str_read != str) assert(false);
#else
    const uint16_t hash = ANSStringHash(debug_prefix_, label, extra...);
    const uint16_t hash_read = ReadDebugUValue(16);
    if (hash_read != hash) assert(false);
#endif  // defined(WP2_ENC_DEC_DEEP_MATCH)
#endif  // defined(WP2_ENC_DEC_MATCH)
  }

  // internal ReadXXX versions without label[]
  uint32_t ReadSymbolInternal(const ANSSymbolInfo codes[],
                              uint32_t log2_tab_size);
  uint32_t ReadSymbolInternal(const ANSSymbolInfo codes[],
                              uint32_t log2_tab_size, uint32_t max_index);
  uint32_t ReadSymbolInternal(const ANSAdaptiveSymbol& asym);
  uint32_t ReadSymbolInternal(const ANSAdaptiveSymbol& asym,
                              uint32_t max_index);
  uint32_t ReadRValueInternal(uint32_t range);
  uint32_t ReadUValueInternal(uint32_t bits);
  uint32_t ReadABitInternal(ANSBinSymbol* stats);
  uint32_t ReadBitInternal(uint32_t num_zeros, uint32_t num_total);

  // Shift 'state' and read-in the lower bits. Return the new value.
  // Typically called as: 'state_ = ReadNextWord(state_);'
  inline uint32_t ReadNextWord(uint32_t state);

  DataSource* data_source_;
  uint32_t state_;
  WP2Status status_;

#if defined(WP2_ENC_DEC_MATCH)
  uint32_t ReadDebugUValue(uint32_t bits);
#endif
#if defined(WP2_BITTRACE) || defined(WP2_TRACE) || defined(WP2_ENC_DEC_MATCH)
  std::string debug_prefix_;
#endif

 public:
  // Returns the number of bytes read from the bitstream.
  uint32_t GetNumUsedBytes() const;
  // Returns a lower bound of the number of bytes completely decoded between
  // 'num_bytes_before' and 'num_bytes_after'.
  static uint32_t GetMinNumUsedBytesDiff(uint32_t num_used_bytes_before,
                                         uint32_t num_used_bytes_after);

 protected:
#if defined(WP2_BITTRACE)
  double last_pos_ = 0.;
  double cur_pos_ = 0.;
  BitCounts counters_;
  // Next reads will be registered in 'counters_custom_' with this key.
  std::string counters_custom_key_;
  // Custom counter on which the user has write access (e.g. to
  // initialize/clean).
  BitCounts counters_custom_;
  // If true, merge custom bit traces until next Pop().
  bool merge_custom_until_pop_ = false;
  // Tells whether the padding has been added to counters_ yet.
  bool is_padding_counted_ = false;

  void BitTrace(uint32_t value, LabelStats::Type type, WP2_OPT_LABEL);

 public:
  // Precision of GetBitCount() compared to the actual used bits is at most
  // kANSPaddingCost or 0.16%.
  static constexpr double kBitCountAccuracy = 0.0016;
  // This string will be prepended to debug prefixes when added to the ANSDec
  // to keep track of their call number. It should be a string the user
  // will not use in a prefix.
  static constexpr const char* kPrefixStr = "____BITTRACE____";

  const BitCounts& GetBitTraces() const { return counters_; }
  BitCounts& GetBitTracesCustom() { return counters_custom_; }
  void ClearBitTracesCustom() { counters_custom_.clear(); }
  const std::string& GetBitTracesCustomKey() const {
    return counters_custom_key_;
  }
  void PushBitTracesCustomPrefix(const char* const suffix_of_prefix,
                                 bool merge_until_pop = false);
  void PopBitTracesCustomPrefix(const char* const suffix_of_prefix);
  double GetBitCount() const { return cur_pos_; }
#else
  // Without WP2_BITTRACE, bit tracing compiles to nothing.
  static inline void BitTrace(uint32_t value, LabelStats::Type type,
                              WP2_OPT_LABEL) {
    (void)value;
    (void)type;
    (void)label;
  }

 public:
  void PushBitTracesCustomPrefix(const char* const, bool = false) {}
  void PopBitTracesCustomPrefix(const char* const) {}
  double GetBitCount() const { return 0.; }
#endif
};
//------------------------------------------------------------------------------
} // namespace WP2
#endif // WP2_UTILS_ANS_H_