blob: ba6f8f6f60a073809adc17b011cda964eb50f0fd [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -----------------------------------------------------------------------------
//
// Asymmetric Numeral System coder
//
// Author: Skal (pascal.massimino@gmail.com)
#include "src/utils/ans.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <iterator>
#include "src/dsp/dsp.h"
#include "src/dsp/math.h"
#include "src/utils/data_source.h"
#include "src/utils/utils.h"
#include "src/wp2/base.h"
#include "src/wp2/format_constants.h"
// Constants derived from IO_BITS / IO_LIMIT_BITS (both defined elsewhere):
// IO_BYTES is the number of bytes requested from the data source per refill.
#define IO_BYTES (IO_BITS / 8)
// IO_LIMIT is the threshold under which the decoder state is refilled
// (see the 'state_ < IO_LIMIT' checks in the ANSDec::Read*() functions).
#define IO_LIMIT (1ull << IO_LIMIT_BITS)
namespace WP2 {
// Returns a copy of this symbol whose offset/freq are rescaled so that the
// range [0, max.offset + max.freq) maps onto [0, tab_size).
// 'max' is the last symbol of the distribution: the end of its range defines
// the total being rescaled.
ANSSymbolInfo ANSSymbolInfo::Rescale(const ANSSymbolInfo& max,
                                     uint32_t tab_size) const {
  const uint32_t total = max.offset + max.freq;
  // Scaled position at which the *next* symbol starts.
  const uint32_t scaled_end = (offset + freq) * tab_size / total;
  // Every symbol before 'max' may end anywhere; the last one must end exactly
  // at tab_size.
  assert(offset < max.offset || scaled_end == tab_size);

  ANSSymbolInfo scaled;
  scaled.symbol = symbol;
  scaled.offset = offset * tab_size / total;
  // The frequency is derived from the next scaled offset (rather than from
  // scaling 'freq' directly) so that consecutive symbols leave no gap.
  scaled.freq = scaled_end - scaled.offset;
  return scaled;
}
//------------------------------------------------------------------------------
// ANSBinSymbol
// Builds a binary symbol from initial counts: 'p0' occurrences of bit 0 and
// 'p1' occurrences of bit 1.
ANSBinSymbol::ANSBinSymbol(uint32_t p0, uint32_t p1) {
  // Only the number of zeros and the grand total are stored.
  num_total_ = p0 + p1;
  num_zeros_ = p0;
}
// Records one more observation of 'bit' (0 or non-zero) in the counters.
// The total must still be below kMaxSum when this is called.
void ANSBinSymbol::UpdateCount(uint32_t bit) {
  assert(num_total_ < kMaxSum);
  ++num_total_;
  if (bit == 0) ++num_zeros_;
}
//------------------------------------------------------------------------------
// update-table (hardcoded for APROBA_MAX_SYMBOL == 16)
// For mixing distributions while preserving norms, see:
// https://fgiesen.wordpress.com/2015/02/20/mixing-discrete-probability-distributions/ NOLINT
// Adapts the cumulative distribution after symbol 'sym' was coded.
// 'sym' must be lower than max_symbol_.
void ANSAdaptiveSymbol::Update(uint32_t sym) {
  assert(sym < max_symbol_);
  // The update itself uses the adaptation factor in effect *before* the decay
  // applied just below.
  const uint32_t mult = adapt_factor_;
  const bool is_aom = (method_ == Method::kAOM);
  if (is_aom && adapt_factor_ > kMinAOMAdaptSpeed) {
    // In kAOM mode the factor shrinks after each update until it reaches
    // kMinAOMAdaptSpeed.
    adapt_factor_ -= adapt_factor_ >> kAOMAdaptSpeedShift;
  }
  // TODO(skal): SSE or table-free for plain-C
  // The target distribution for 'sym' is a window inside cdf_var_ whose
  // starting position depends on the symbol.
  const uint32_t var_start = APROBA_MAX_SYMBOL - sym - 1;
  const uint16_t* const cdf_var = &cdf_var_[var_start];
  ANSUpdateCDF(max_symbol_, cdf_base_.data(), cdf_var, mult, cumul_.data());
  // Sanity checks: the cumulative distribution must keep its boundaries.
  assert(cumul_[0] == 0);
  assert(cumul_[max_symbol_ - 1] < APROBA_MAX);
  assert(max_symbol_ == APROBA_MAX_SYMBOL || cumul_[max_symbol_] == APROBA_MAX);
}
//------------------------------------------------------------------------------
// Fills 'table', which must hold 2 * APROBA_MAX_SYMBOL - 1 entries: the first
// APROBA_MAX_SYMBOL entries are set to 0 and the remaining
// APROBA_MAX_SYMBOL - 1 entries to 'APROBA_MAX - nnz'.
static void GenerateVariableCDF(uint32_t nnz, uint16_t* const table) {
  uint16_t* const zeros_end = table + APROBA_MAX_SYMBOL;
  uint16_t* const table_end = table + 2 * APROBA_MAX_SYMBOL - 1;
  std::fill(table, zeros_end, static_cast<uint16_t>(0));
  std::fill(zeros_end, table_end, static_cast<uint16_t>(APROBA_MAX - nnz));
}
// Finalizes the tables once cumul_[0..max_symbol_ - 1] and max_symbol_ are set:
// terminates cumul_, builds cdf_base_ / cdf_var_ and pads the unused tail of
// cumul_.
void ANSAdaptiveSymbol::InitCDF() {
  assert(max_symbol_ >= 1);
  cumul_[max_symbol_] = APROBA_MAX;

  // cdf_base_[s] is the number of used (non-zero frequency) symbols strictly
  // before 's'. Entries at or past max_symbol_ all hold the total count.
  uint32_t num_used = 0;
  for (uint32_t s = 0; s < APROBA_MAX_SYMBOL; ++s) {
    cdf_base_[s] = num_used;
    if (s < max_symbol_ && cumul_[s + 1] > cumul_[s]) ++num_used;
  }
  // TODO(skal): could be static tables, indexed by 'nnz'.
  GenerateVariableCDF(num_used, &cdf_var_[0]);

  if (max_symbol_ < APROBA_MAX_SYMBOL) {
    // It is useful to fill those values for the later search within cumul_:
    // every entry past cumul_[max_symbol_] gets a value above APROBA_MAX.
    std::fill_n(&cumul_[max_symbol_ + 1], APROBA_MAX_SYMBOL - max_symbol_,
                APROBA_MAX + 1);
  }
}
// Initializes the distribution so that the 'max_symbol' symbols evenly share
// the [0, APROBA_MAX) range. 'max_symbol' must be in [1, APROBA_MAX_SYMBOL].
void ANSAdaptiveSymbol::InitFromUniform(uint32_t max_symbol) {
  assert(max_symbol > 0 && max_symbol <= APROBA_MAX_SYMBOL);
  max_symbol_ = max_symbol;
  for (uint32_t sym = 0; sym < max_symbol_; ++sym) {
    const uint32_t start = sym * APROBA_MAX / max_symbol_;
    cumul_[sym] = static_cast<uint16_t>(start);
  }
  // Sets the terminating cumul_ entry and the derived tables.
  InitCDF();
}
// Initializes the distribution from 'max_symbol' occurrence counts.
// The counts are copied and normalized to sum up to APROBA_MAX; the caller's
// array is left untouched. max_symbol_ ends up being one past the last symbol
// with a non-zero normalized count.
// Returns the error reported by ANSNormalizeCounts(), if any (max_symbol_ is
// then left at 0), or WP2_STATUS_OK.
WP2Status ANSAdaptiveSymbol::InitFromCounts(const uint32_t counts[],
                                            uint32_t max_symbol) {
  assert(max_symbol <= APROBA_MAX_SYMBOL);
  max_symbol_ = 0;

  uint32_t normalized[APROBA_MAX_SYMBOL];
  std::copy_n(counts, max_symbol, normalized);
  WP2_CHECK_STATUS(ANSNormalizeCounts(normalized, max_symbol, APROBA_MAX));

  // Accumulate the normalized counts into the cumulative distribution.
  cumul_[0] = 0;
  for (uint32_t sym = 0; sym < max_symbol; ++sym) {
    const uint32_t count = normalized[sym];
    cumul_[sym + 1] = cumul_[sym] + static_cast<uint16_t>(count);
    if (count > 0) max_symbol_ = sym + 1;
  }
  assert(cumul_[max_symbol_] == APROBA_MAX);
  InitCDF();
  return WP2_STATUS_OK;
}
// Initializes the distribution from a cumulative distribution 'cdf' made of
// 'max_symbol' non-decreasing entries starting at 0, with an implicit final
// value of 'max_proba'. The CDF is converted to per-symbol counts and handed
// to InitFromCounts(), whose status is returned.
WP2Status ANSAdaptiveSymbol::InitFromCDF(const uint16_t* const cdf,
                                         uint32_t max_symbol,
                                         uint32_t max_proba) {
  assert(max_symbol > 0 && max_symbol <= APROBA_MAX_SYMBOL);
  assert(max_proba > 0);
  assert(cdf[0] == 0);
  max_symbol_ = max_symbol;

  uint32_t counts[APROBA_MAX_SYMBOL];
  // Each count is the difference between two consecutive CDF entries.
  for (uint32_t next = 1; next < max_symbol; ++next) {
    const uint32_t cur = next - 1;
    assert(cdf[next] >= cdf[cur] && cdf[next] <= max_proba);
    counts[cur] = cdf[next] - cdf[cur];
  }
  // The last symbol takes whatever remains up to 'max_proba'.
  const uint32_t last = max_symbol - 1;
  counts[last] = max_proba - cdf[last];

  WP2_CHECK_STATUS(InitFromCounts(counts, max_symbol));
  return WP2_STATUS_OK;
}
void ANSAdaptiveSymbol::CopyFrom(const ANSAdaptiveSymbol& other) {
max_symbol_ = other.max_symbol_;
adapt_factor_ = other.adapt_factor_;
std::copy(std::begin(other.cumul_), std::end(other.cumul_),
std::begin(cumul_));
std::copy(std::begin(other.cdf_base_), std::end(other.cdf_base_),
std::begin(cdf_base_));
std::copy(std::begin(other.cdf_var_), std::end(other.cdf_var_),
std::begin(cdf_var_));
}
// Selects the adaptation 'method' and the initial adaptation factor.
// With Method::kAOM, 'speed' must be kANSAProbaInvalidSpeed and the factor
// starts at kMaxAOMAdaptSpeed. With any other method, 'speed' must be valid
// and is used as the factor, capped to 0xffff.
void ANSAdaptiveSymbol::SetAdaptationSpeed(Method method, uint32_t speed) {
  method_ = method;
  if (method_ != Method::kAOM) {
    assert(speed != kANSAProbaInvalidSpeed);
    adapt_factor_ = std::min(speed, 0xffffu);
    return;
  }
  assert(speed == kANSAProbaInvalidSpeed);
  adapt_factor_ = kMaxAOMAdaptSpeed;
}
// Debug helper: prints one line to stdout describing the symbol.
// If 'use_percents' is true, the APROBA_MAX_SYMBOL probabilities returned by
// GetProba() are printed as percentages. Otherwise the APROBA_MAX_SYMBOL + 1
// raw entries of cumul_ are printed in hexadecimal.
void ANSAdaptiveSymbol::Print(bool use_percents) const {
  printf(" -- adaptive [max_symbol:%d] (speed factor=%u): ", max_symbol_,
         adapt_factor_);
  if (!use_percents) {
    for (uint32_t idx = 0; idx <= APROBA_MAX_SYMBOL; ++idx) {
      printf("[0x%.4x]", cumul_[idx]);
    }
  } else {
    for (uint32_t sym = 0; sym < APROBA_MAX_SYMBOL; ++sym) {
      const float percent = 100.f * GetProba(sym);
      printf("[%5.2f%%]", percent);
    }
  }
  printf("\n");
}
// Debug helper: prints the per-symbol frequencies to stdout as a
// brace-enclosed, comma-separated list of APROBA_MAX_SYMBOL values.
// 'norm' is the sum the frequencies are rescaled to before printing; 0 (or
// APROBA_MAX) prints the raw differences of consecutive cumul_ entries.
void ANSAdaptiveSymbol::PrintProbas(uint32_t norm) const {
  if (norm == 0) norm = APROBA_MAX;
  uint32_t pdf[APROBA_MAX_SYMBOL];
  // NOTE(review): entries at or past max_symbol_ are derived from the padding
  // that InitCDF() writes into cumul_ -- confirm this is the intended output
  // when max_symbol_ < APROBA_MAX_SYMBOL.
  for (uint32_t i = 0; i < APROBA_MAX_SYMBOL; ++i) {
    pdf[i] = cumul_[i + 1] - cumul_[i];
  }
  if (norm != APROBA_MAX) {
    // Rescale to the requested 'norm' (not to APROBA_MAX, which would leave the
    // values unchanged and ignore the parameter). The status is deliberately
    // ignored: this is a best-effort debug print.
    (void)ANSNormalizeCounts(pdf, APROBA_MAX_SYMBOL, norm);
  }
  printf(" { ");
  for (uint32_t i = 0; i < APROBA_MAX_SYMBOL; ++i) {
    if (i > 0) printf(", ");
    printf("%u", pdf[i]);  // pdf[] is unsigned.
  }
  printf(" },\n");
}
//------------------------------------------------------------------------------
// Decoder
#if defined(WP2_BITTRACE)
// Out-of-class definition of the static constexpr member (its value is given
// at its in-class declaration). Before C++17 such a definition is required
// whenever the member is odr-used.
constexpr double ANSDec::kBitCountAccuracy;
#endif
// Shifts 's' left by IO_BITS and fills the freed low bits with the next
// IO_BYTES bytes of the data source (first byte = least significant).
// If the bytes are not available, status_ is set to
// WP2_STATUS_NOT_ENOUGH_DATA and the shifted value is returned with its low
// bits left at zero.
uint32_t ANSDec::ReadNextWord(uint32_t s) {
  const uint32_t shifted = s << IO_BITS;
  const uint8_t* bytes;
  if (!data_source_->TryReadNext(IO_BYTES, &bytes)) {
    status_ = WP2_STATUS_NOT_ENOUGH_DATA;
    return shifted;
  }
  // Written as endian-neutral code. Exactly two bytes are combined here.
  return shifted | ((bytes[1] << 8) | bytes[0]);
}
// Decodes one bit whose probability of being 0 is num_zeros / num_total.
// Returns the decoded bit (0 or 1).
uint32_t ANSDec::ReadBitInternal(uint32_t num_zeros, uint32_t num_total) {
  // Degenerate distributions: the bit is known and the state is left untouched.
  if (num_zeros == 0) return 1;
  if (num_zeros == num_total) return 0;

  // Probabilities of 0 and of 1, expressed out of PROBA_MAX.
  const uint32_t p0 = (num_zeros << PROBA_BITS) / num_total;
  const int q0 = PROBA_MAX - p0;
  // The low PROBA_BITS bits of the state select the bit; the remaining high
  // bits are carried over into the new state.
  const uint32_t xfrac = state_ & (PROBA_MAX - 1);
  const auto upper = state_ >> PROBA_BITS;
  const uint32_t bit = (xfrac >= p0);
  if (bit) {
    state_ = q0 * upper + xfrac - p0;
  } else {
    state_ = p0 * upper + xfrac;
  }
  // Refill the state as soon as it drops below the I/O threshold.
  if (state_ < IO_LIMIT) state_ = ReadNextWord(state_);
#if defined(WP2_BITTRACE)
  cur_pos_ += PROBA_BITS - WP2Log2(bit ? q0 : p0);
#endif
  return bit;
}
// Decodes one bit using the counts currently held by 'stats', then feeds the
// decoded bit to stats->Update() and returns that call's result.
uint32_t ANSDec::ReadABitInternal(ANSBinSymbol* const stats) {
  const uint32_t num_zeros = stats->NumZeros();
  const uint32_t num_total = stats->NumTotal();
  const uint32_t decoded = ReadBitInternal(num_zeros, num_total);
  return stats->Update(decoded);
}
// Decodes a symbol from the table 'codes', restricted to the entries up to
// 'max_index': frequencies and offsets are rescaled on the fly so that the
// range [0, max_index + codes[max_index].freq) maps onto
// [0, 1 << log2_tab_size).
// Returns the decoded symbol.
uint32_t ANSDec::ReadSymbolInternal(const ANSSymbolInfo codes[],
uint32_t log2_tab_size,
uint32_t max_index) {
const uint32_t tab_size = (1u << log2_tab_size);
// The low 'log2_tab_size' bits of the state select the symbol.
const uint32_t res = state_ & (tab_size - 1);
// Deduce the scaled freq/offset.
ANSSymbolInfo s;
// A good start is the scaled index. It is almost always the right value,
// except for very small frequencies, because of rounding errors.
const uint32_t max_pos = max_index + codes[max_index].freq;
uint32_t i = res * max_pos / tab_size;
// NOTE(review): presumably codes[i].offset is the distance from entry 'i' back
// to the first entry of its symbol, so this rewinds to the start of the
// symbol -- confirm against the code that builds 'codes'.
i -= codes[i].offset;
uint32_t offset_next;
// Walk forward, one symbol at a time, until the scaled end of the current
// symbol is past 'res'.
while (true) {
offset_next = (i + codes[i].freq) * tab_size / max_pos;
if (res < offset_next) break;
i += codes[i].freq;
}
// Scaled offset of the symbol; its frequency is derived from the next scaled
// offset.
s.offset = i * tab_size / max_pos;
s.freq = offset_next - s.offset;
s.symbol = codes[i].symbol;
assert(res >= s.offset && res < offset_next);
// State update, followed by a refill if the state dropped below the I/O
// threshold.
state_ = s.freq * (state_ >> log2_tab_size) + (res - s.offset);
if (state_ < IO_LIMIT) state_ = ReadNextWord(state_);
#if defined(WP2_BITTRACE)
cur_pos_ += log2_tab_size - WP2Log2(s.freq);
#endif
return s.symbol;
}
uint32_t ANSDec::ReadSymbolInternal(const ANSAdaptiveSymbol& asym,
uint32_t max_index) {
const uint32_t res = state_ & ((1u << ANS_LOG_TAB_SIZE) - 1);
// Deduce the scaled freq/offset.
// TODO(maryla): optimize so we don't do a linear scan from he first symbol.
ANSSymbolInfo s;
s.offset = s.freq = s.symbol = 0;
const ANSSymbolInfo info_max = asym.GetInfo(max_index);
for (uint32_t i = 0u; i <= max_index; ++i) {
s = asym.GetInfo(i).Rescale(info_max);
if (res < (uint32_t)s.offset + s.freq) break;
}
assert(res >= s.offset);
state_ = s.freq * (state_ >> ANS_LOG_TAB_SIZE) + (res - s.offset);
if (state_ < IO_LIMIT) state_ = ReadNextWord(state_);
#if defined(WP2_BITTRACE)
cur_pos_ += ANS_LOG_TAB_SIZE - WP2Log2(s.freq);
#endif
return s.symbol;
}
uint32_t ANSDec::ReadSymbolInternal(const ANSSymbolInfo codes[],
uint32_t log2_tab_size) {
const uint32_t res = state_ & ((1u << log2_tab_size) - 1);
const ANSSymbolInfo s = codes[res];
state_ = s.freq * (state_ >> log2_tab_size) + s.offset;
if (state_ < IO_LIMIT) state_ = ReadNextWord(state_);
#if defined(WP2_BITTRACE)
cur_pos_ += log2_tab_size - WP2Log2(s.freq);
#endif
return s.symbol;
}
uint32_t ANSDec::ReadSymbolInternal(const ANSAdaptiveSymbol& asym) {
const uint32_t proba = state_ & (APROBA_MAX - 1);
const ANSSymbolInfo s = asym.GetSymbol(proba);
state_ = s.freq * (state_ >> APROBA_BITS) + (proba - s.offset);
if (state_ < IO_LIMIT) state_ = ReadNextWord(state_);
#if defined(WP2_BITTRACE)
cur_pos_ += ANSAdaptiveSymbol::GetFreqCost(s.freq);
#endif
return s.symbol;
}
// Reads a uniformly distributed value of 'bits' bits (at most
// kANSMaxUniformBits): the value is the low 'bits' bits of the state, which
// are then shifted out.
uint32_t ANSDec::ReadUValueInternal(uint32_t bits) {
  assert(bits <= kANSMaxUniformBits);
  const uint32_t low_mask = (1u << bits) - 1;
  const uint32_t value = state_ & low_mask;
  state_ >>= bits;
  // Refill the state if it dropped below the I/O threshold.
  if (state_ < IO_LIMIT) state_ = ReadNextWord(state_);
#if defined(WP2_BITTRACE)
  cur_pos_ += bits;
#endif
  return value;
}
// Reads a uniformly distributed value in [0, range). 'range' must be at most
// 1 << kANSMaxRangeBits.
uint32_t ANSDec::ReadRValueInternal(uint32_t range) {
  assert(range <= (1u << kANSMaxRangeBits));
  // The value is the remainder of the state, the new state its quotient.
  uint32_t value = state_ % range;
  state_ /= range;
  if (state_ < IO_LIMIT) {
    // Refill: the freshly read word is appended to the remainder, and that
    // combined number is divided by 'range' again. Its quotient goes into the
    // low part of the state and its remainder becomes the value.
    const uint32_t extended = ReadNextWord(value);
    state_ = (state_ << IO_LIMIT_BITS) + extended / range;
    value = extended % range;
  }
#if defined(WP2_BITTRACE)
  cur_pos_ += WP2Log2(range);
#endif
  return value;
}
// (Re)starts decoding from 'data_source' and records the outcome in status_:
//  - WP2_STATUS_NULL_PARAMETER if 'data_source' is null,
//  - WP2_STATUS_NOT_ENOUGH_DATA if the initial words cannot be read (set by
//    ReadNextWord()),
//  - WP2_STATUS_BITSTREAM_ERROR if the initial state is below IO_LIMIT,
//  - WP2_STATUS_OK otherwise.
void ANSDec::Init(DataSource* const data_source) {
  data_source_ = data_source;
  state_ = 0;
  if (data_source == nullptr) {
    status_ = WP2_STATUS_NULL_PARAMETER;
    return;
  }
  status_ = WP2_STATUS_OK;
  // The initial state is made of two words.
  for (int word = 0; word < 2; ++word) state_ = ReadNextWord(state_);
  if (status_ == WP2_STATUS_OK && state_ < IO_LIMIT) {
    status_ = WP2_STATUS_BITSTREAM_ERROR;
  }
}
//------------------------------------------------------------------------------
// Bit-counters
#if defined(WP2_ENC_DEC_MATCH)
// Only compiled when WP2_ENC_DEC_MATCH is defined. Reads a 'bits'-wide uniform
// value by forwarding to ReadUValueInternal().
uint32_t ANSDec::ReadDebugUValue(uint32_t bits) {
return ReadUValueInternal(bits);
}
#endif
#if defined(WP2_BITTRACE)
// Attributes the cost accumulated since the previous call (cur_pos_ minus
// last_pos_) to the counter named debug_prefix_ + 'label', and records
// 'value' in that counter's histogram along with 'type'. The same cost is
// also added to the custom counters. The first call additionally accounts
// for the ANS padding.
void ANSDec::BitTrace(uint32_t value, LabelStats::Type type,
                      const char label[]) {
  if (!is_padding_counted_) {
    is_padding_counted_ = true;
    const double incr = kANSPaddingCost;
    auto& padding = counters_["ANS_padding"];
    padding.bits += incr;
    padding.num_occurrences += 1;
    // Both positions move, so the padding is not attributed to 'label' below.
    cur_pos_ += incr;
    last_pos_ += incr;
  }
  const auto cost = cur_pos_ - last_pos_;
  const std::string key = debug_prefix_ + std::string(label);
  auto& stats = counters_[key];
  stats.bits += cost;
  ++stats.num_occurrences;
  ++stats.histo[value];
  stats.type = type;
  // Add stats for the custom counter (even if no custom key, not to miss any).
  if (merge_custom_until_pop_) {
    // Merged mode: everything goes to the pushed key. Its occurrence was
    // already counted by PushBitTracesCustomPrefix().
    assert(!counters_custom_key_.empty());
    counters_custom_[counters_custom_key_].bits += cost;
  } else {
    const std::string custom_key =
        counters_custom_key_.empty() ? key
                                     : (counters_custom_key_ + "/" + label);
    auto& custom_stats = counters_custom_[custom_key];
    custom_stats.bits += cost;
    ++custom_stats.num_occurrences;
  }
  last_pos_ = cur_pos_;
}
// Appends the non-empty 'suffix_of_prefix' to the custom counters key, using
// '/' as separator. If 'merge_until_pop' is true, one occurrence is counted
// now for the resulting key and later BitTrace() calls only add their bits to
// it, until PopBitTracesCustomPrefix() is called. Must not be called while
// such a merge is already active.
void ANSDec::PushBitTracesCustomPrefix(const char* const suffix_of_prefix,
                                       bool merge_until_pop) {
  assert(suffix_of_prefix[0] != '\0');
  assert(!merge_custom_until_pop_);
  if (!counters_custom_key_.empty()) counters_custom_key_ += "/";
  counters_custom_key_ += suffix_of_prefix;
  if (!merge_until_pop) return;
  merge_custom_until_pop_ = true;
  // Increment by one now, merge next ones.
  ++counters_custom_[counters_custom_key_].num_occurrences;
}
// Removes 'suffix_of_prefix' (and the '/' preceding it, if any) from the end
// of the custom counters key and stops any active merge. 'suffix_of_prefix'
// must match what was last pushed.
void ANSDec::PopBitTracesCustomPrefix(const char* const suffix_of_prefix) {
  merge_custom_until_pop_ = false;
  if (counters_custom_key_ == suffix_of_prefix) {
    // The suffix is the whole key: nothing remains.
    counters_custom_key_.clear();
    return;
  }
  const uint32_t suffix_length = std::strlen(suffix_of_prefix);
  assert(counters_custom_key_.size() > suffix_length);
  // Length of what remains once "/suffix" is dropped.
  const uint32_t kept_length = counters_custom_key_.size() - suffix_length - 1;
  assert(counters_custom_key_[kept_length] == '/' &&
         counters_custom_key_.substr(kept_length + 1) == suffix_of_prefix);
  counters_custom_key_.resize(kept_length);
}
#endif
// Returns the number of bytes consumed so far, as reported by the data source:
// discarded bytes plus read bytes.
uint32_t ANSDec::GetNumUsedBytes() const {
  const auto num_discarded = data_source_->GetNumDiscardedBytes();
  const auto num_read = data_source_->GetNumReadBytes();
  return num_discarded + num_read;
}
// Returns a lower bound of the number of bytes used between two values of
// GetNumUsedBytes(), computed as SafeSub(after, before + IO_BYTES).
uint32_t ANSDec::GetMinNumUsedBytesDiff(uint32_t num_used_bytes_before,
                                        uint32_t num_used_bytes_after) {
  // Precision is +-IO_BYTES.
  const uint32_t latest_start = num_used_bytes_before + IO_BYTES;
  return SafeSub(num_used_bytes_after, latest_start);
}
//------------------------------------------------------------------------------
} // namespace WP2