// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -----------------------------------------------------------------------------
//
// Decoding of ANS symbols.
//
// Author: Vincent Rabaud (vrabaud@google.com)
#include "src/dec/symbols_dec.h"
#include <algorithm>
#include <numeric>
#include "src/dsp/math.h"
#include "src/utils/ans_utils.h"
namespace WP2 {
// 'nnz' is the number of non-zero values in the histogram.
// 'max_count' is the maximum value of a count in the histogram.
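// The decoded (symbol, frequency) pairs are written to 'infos', one per
// non-zero count.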
static WP2Status ReadHistogram(uint32_t nnz, uint32_t symbol_range,
uint32_t max_count, ANSDec* const dec,
VectorNoCtor<ANSSymbolInfo>& infos) {
// Read the mappings first.
Vector_u16 mapping;
uint32_t histo_size = nnz;
bool is_sparse;
if (nnz < symbol_range) {
is_sparse = dec->ReadBool("is_sparse");
if (is_sparse) {
WP2_CHECK_STATUS(LoadMapping(dec, nnz, symbol_range, mapping));
} else {
histo_size = dec->ReadRange(nnz, symbol_range, "histogram_size");
}
} else {
// If we use the whole range, force sparse usage. This seems
// counter-intuitive but it is due to the optimization done for
// sparse data: 1 is subtracted from all counts[] values.
is_sparse = true;
}
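// In sparse mode all read counts are non-zero (their positions are given by
// 'mapping', or are the identity when the full range is used); otherwise
// 'histo_size' counts, zeros included, are read.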
Vector_u16 counts;
WP2_CHECK_ALLOC_OK(counts.resize(histo_size));
// 'type' is 0 for raw probabilities, 1 for Huffman-represented
// probabilities and 2 for Huffman-represented probabilities compressed
// recursively.
const uint8_t type = dec->ReadRValue(3, "coding_type");
const uint32_t max_count_bits =
std::min(kMaxFreqBits, 1u + (uint32_t)WP2Log2Floor(max_count));
if (type == 2) { // HuffmanANS
// Read the probabilities we depend upon.
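// The counts below are stored as Huffman-style bit-lengths, and the
// distribution of those bit-lengths is itself ANS-coded: the recursive call
// decodes that sub-histogram first.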
VectorNoCtor<ANSSymbolInfo> infos_sub;
const uint32_t size_sub = dec->ReadRange(1, 1 + kMaxFreqBits, "size");
WP2_CHECK_STATUS(
ReadHistogram(size_sub, 1 + kMaxFreqBits, histo_size, dec, infos_sub));
// Deduce the counts.
uint16_t max_symbol = 0;
for (const auto& i : infos_sub) {
max_symbol = std::max(max_symbol, i.symbol);
}
// Count the symbols.
Vector_u32 counts_sub;
WP2_CHECK_ALLOC_OK(counts_sub.resize(max_symbol + 1));
std::fill(counts_sub.begin(), counts_sub.end(), 0);
for (const auto& i : infos_sub) counts_sub[i.symbol] = i.freq;
// Build the spread table.
VectorNoCtor<ANSSymbolInfo> codes;
WP2_CHECK_STATUS(ANSCountsToSpreadTable(
counts_sub.data(), counts_sub.size(), ANS_LOG_TAB_SIZE, codes));
// Read the symbols that make up the final stats.
uint32_t nnz_left = nnz;
for (uint32_t i = 0; i < histo_size; ++i) {
if (nnz_left > 0) {
counts[i] = dec->ReadSymbol(codes.data(), ANS_LOG_TAB_SIZE, "counts");
} else {
counts[i] = 0; // Only zeros left.
}
// If all remaining values must be non-zero (or in sparse mode), counts
// were stored minus 1.
if (is_sparse || (i + nnz_left == histo_size)) ++counts[i];
if (counts[i] != 0) --nnz_left;
}
} else {
// Read the probabilities.
uint32_t max_freq_bits;
if (type == 1) { // Huffman
max_freq_bits =
dec->ReadRange(1, 1 + WP2Log2Floor(max_count_bits), "max_freq_bits");
} else { // Raw
max_freq_bits = dec->ReadRange(1, max_count_bits, "max_freq_bits");
}
if (is_sparse) {
ReadVector(dec, (1 << max_freq_bits) - 1, counts);
for (auto& c : counts) c += 1;
} else {
ReadVectorNnz(dec, nnz, (1 << max_freq_bits) - 1, counts);
}
}
// Store everything in the info.
WP2_CHECK_ALLOC_OK(infos.resize(nnz));
const bool do_huffman = (type == 1 || type == 2);
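// In Huffman mode, a stored count c stands for a frequency of 1 << (c - 1).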
for (uint32_t ind = 0, k = 0; k < counts.size(); ++k) {
if (counts[k] > 0) {
assert(ind < nnz); // We cannot have more non-zeros than needed.
infos[ind].symbol = mapping.empty() ? k : mapping[k];
infos[ind].freq = do_huffman ? (1 << (counts[k] - 1)) : counts[k];
++ind;
}
}
return dec->GetStatus();
}
//------------------------------------------------------------------------------
static bool ANSSymbolInfoCmp(const ANSSymbolInfo& i1, const ANSSymbolInfo& i2) {
return (i1.symbol < i2.symbol);
}
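// Typical decoding flow (sketch, caller-side; 'symbols_info', 'dec',
// 'max_nnz', 'cluster' and 'sym' are the caller's, and the label arguments
// depend on the WP2_OPT_LABEL build configuration):
//   SymbolReader reader;
//   WP2_CHECK_STATUS(reader.Init(symbols_info, &dec));
//   WP2_CHECK_STATUS(reader.ReadHeader(max_nnz, sym, "sym"));
//   const int32_t value = reader.Read(cluster, sym, "sym", /*cost=*/nullptr);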
WP2Status SymbolReader::Init(const SymbolsInfo& symbols_info,
ANSDec* const dec) {
dec_ = dec;
// Get an upper bound on the number of symbols that might use a dictionary.
uint32_t num_auto = 0u;
for (uint32_t s = 0; s < symbols_info.Size(); ++s) {
if (symbols_info.Method(s) == SymbolsInfo::StorageMethod::kAuto) {
num_auto += symbols_info.NumClusters(s);
}
}
WP2_CHECK_ALLOC_OK(codes_.reserve(num_auto));
return SymbolIO<StatExtra>::Init(symbols_info);
}
void SymbolReader::AddTrivial(uint32_t cluster, uint32_t sym, int32_t value) {
Stat* const stat = GetStats(cluster, sym);
stat->type = Stat::Type::kTrivial;
stat->param.trivial_value = value;
}
void SymbolReader::AddRange(uint32_t cluster, uint32_t sym,
const VectorNoCtor<ANSSymbolInfo>* const infos,
uint16_t max_range) {
Stat* const stat = GetStats(cluster, sym);
stat->type = Stat::Type::kRange;
if (infos != nullptr) {
stat->use_mapping = true;
for (uint32_t k = 0; k < infos->size(); ++k) {
stat->mappings[k] = (*infos)[k].symbol;
}
stat->range = infos->size();
} else {
stat->use_mapping = false;
stat->range = max_range;
}
}
WP2Status SymbolReader::AddDict(uint32_t cluster, uint32_t sym,
VectorNoCtor<ANSSymbolInfo>* const infos) {
Stat* const stat = GetStats(cluster, sym);
assert(infos->size() > 1);
stat->type = Stat::Type::kDict;
stat->use_mapping = true;
// Make sure the symbols are ordered as they were during encoding.
std::sort(infos->begin(), infos->end(), ANSSymbolInfoCmp);
// Create the mappings.
Vector_u32 counts;
WP2_CHECK_ALLOC_OK(counts.resize(infos->size()));
for (uint32_t k = 0; k < counts.size(); ++k) {
stat->mappings[k] = (*infos)[k].symbol;
counts[k] = (*infos)[k].freq;
}
assert(infos->size() <= ANS_MAX_SYMBOLS);
// Convert the histogram to a spread table.
WP2_CHECK_ALLOC_OK(codes_.resize(codes_.size() + 1));
stat->extra.log2_tab_size = ANS_LOG_TAB_SIZE;
WP2_CHECK_STATUS(ANSCountsToSpreadTable(
&counts[0], counts.size(), stat->extra.log2_tab_size,
codes_.back())); // will allocate codes_.back()
stat->extra.codes = codes_.back().data();
return WP2_STATUS_OK;
}
WP2Status SymbolReader::AddGolomb(uint32_t cluster, uint32_t sym,
VectorNoCtor<WP2::ANSSymbolInfo>* const infos,
uint32_t prefix_size) {
WP2_CHECK_STATUS(AddDict(cluster, sym, infos));
SetGolombStat(cluster, sym, prefix_size);
return WP2_STATUS_OK;
}
int32_t SymbolReader::Read(uint32_t cluster, uint32_t sym, WP2_OPT_LABEL,
double* const cost) {
int32_t value;
const WP2Status status = ReadInternal(cluster, sym, /*use_max_value=*/false,
/*max_value=*/0, label, &value, cost);
(void)status;
assert(status == WP2_STATUS_OK);
// TODO(yguyon): Enforce dec_->GetStatus() checking more widely.
// Currently some failures are silent.
return value;
}
WP2Status SymbolReader::TryRead(uint32_t cluster, uint32_t sym,
uint32_t max_value, WP2_OPT_LABEL,
int32_t* const value, double* const cost) {
const WP2Status status = ReadInternal(cluster, sym, /*use_max_value=*/true,
max_value, label, value, cost);
if (dec_->GetStatus() != WP2_STATUS_OK) return dec_->GetStatus();
return status;
}
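// Returns the start index, in the spread table of 'stat', of the interval
// holding the largest symbol whose mapped value is <= 'max_value' (0 if even
// the first symbol maps above 'max_value').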
uint32_t SymbolReader::FindSymbol(const Stat& stat, uint32_t max_value) {
if (stat.mappings[stat.extra.codes[0].symbol] > max_value) {
// If the first symbol is already bigger than the max_value, stop here.
return 0;
}
uint32_t i_sup = (1u << stat.extra.log2_tab_size) - 1;
i_sup -= stat.extra.codes[i_sup].offset;
if (stat.mappings[stat.extra.codes[i_sup].symbol] <= max_value) {
// Stop if the last interval is actually the right one.
return i_sup;
}
// Use a binary search to find i_inf such that:
// stat.mappings[symbol at i_inf] <= max_value
// and stat.mappings[next symbol after the one at i_inf] > max_value
uint32_t i_inf = 0;
while (true) {
const uint32_t i_inf_next = i_inf + stat.extra.codes[i_inf].freq;
// Stop if we reached a good interval.
if (stat.mappings[stat.extra.codes[i_inf_next].symbol] > max_value) {
break;
}
i_inf = i_inf_next;
uint32_t i_mid = (i_inf + i_sup) / 2;
const ANSSymbolInfo& info_mid = stat.extra.codes[i_mid];
i_mid -= info_mid.offset;
if (stat.mappings[info_mid.symbol] > max_value) {
i_sup = i_mid;
} else {
i_inf = i_mid;
}
}
// Make sure the interval is valid: check its symbol.
const ANSSymbolInfo& info_inf = stat.extra.codes[i_inf];
assert(stat.mappings[info_inf.symbol] <= max_value);
// Check if we are the last interval.
if (i_inf + info_inf.freq == (1u << stat.extra.log2_tab_size)) return i_inf;
// If not, check the symbol of the next interval.
const ANSSymbolInfo& info_inf_next = stat.extra.codes[i_inf + info_inf.freq];
if (stat.mappings[info_inf_next.symbol] <= max_value) assert(false);
return i_inf;
}
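// Walks the spread table interval by interval until it reaches the interval
// of symbol 'value' (or the last interval) and returns its entry, with
// 'offset' set to the interval's start index.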
static ANSSymbolInfo GetInfo(const ANSSymbolInfo* const codes, uint32_t value) {
uint32_t i = 0;
while (codes[i].symbol < value &&
i + codes[i].freq < (1 << ANS_LOG_TAB_SIZE)) {
i += codes[i].freq;
}
ANSSymbolInfo info;
info.offset = i;
info.freq = codes[i].freq;
info.symbol = codes[i].symbol;
return info;
}
// Returns the index of 'mapped' in the 'mappings' array. If 'mapped' is not
// in the array, the index of the largest value smaller than it is returned
// (or 0 if there is no such value).
static uint16_t Unmap(const uint16_t* const mappings, uint32_t mapping_size,
uint16_t mapped) {
if (mappings[0] > mapped) return 0;
uint32_t i_sup = mapping_size - 1;
if (mappings[i_sup] <= mapped) return i_sup;
// Use a binary search to find i_inf such that:
//   mappings[i_inf] <= mapped
// and mappings[i_inf + 1] > mapped
uint32_t i_inf = 0;
while (true) {
// Stop if we reached a good interval.
if (mappings[i_inf + 1] > mapped) {
break;
}
++i_inf;
uint32_t i_mid = (i_inf + i_sup) / 2;
if (mappings[i_mid] > mapped) {
i_sup = i_mid;
} else {
i_inf = i_mid;
}
}
assert(mappings[i_inf] <= mapped);
assert(i_inf == (mapping_size - 1) || mappings[i_inf + 1] > mapped);
return i_inf;
}
// Returns the frequency value for symbol "value", scaled to take into account
// that the max possible value is "max_value". Both "value" and "max_value"
// should be raw values (not mapped).
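// The spread table is conceptually truncated after the interval of
// "max_value" and rescaled to the full table size; the returned value is the
// width of "value"'s interval after that rescaling.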
static uint32_t ScaleFreq(const ANSSymbolInfo* codes, uint32_t log2_tab_size,
uint32_t value, uint32_t max_value) {
const uint32_t tab_size = (1u << log2_tab_size);
const ANSSymbolInfo& info = GetInfo(codes, value);
const ANSSymbolInfo& info_max = GetInfo(codes, max_value);
const uint32_t offset =
info.offset * tab_size / (info_max.offset + info_max.freq);
const uint32_t offset_next =
(info.offset + info.freq) * tab_size / (info_max.offset + info_max.freq);
return offset_next - offset;
}
WP2Status SymbolReader::ReadInternal(uint32_t cluster, uint32_t sym,
bool use_max_value, uint32_t max_value,
WP2_OPT_LABEL, int32_t* const value,
double* const cost) {
// Consider reading a symbol as a single read occurrence.
dec_->PushBitTracesCustomPrefix(label, /*merge_until_pop=*/true);
const Stat& stat = *GetStats(cluster, sym);
if (use_max_value && stat.use_mapping) {
// Make sure at least one value in the mapping is smaller than or equal to
// max_value. We check the minimal/first one.
// TODO(vrabaud) This could be optimized if, at encoding time, the minimal
//               stat.mappings values always equal max_value when hit: they
//               could then be removed from stat.mappings and max_value used
//               instead.
WP2_CHECK_OK(stat.mappings[0] <= max_value, WP2_STATUS_BITSTREAM_ERROR);
}
ANSDebugPrefix debug_prefix(dec_, label);
switch (stat.type) {
case Stat::Type::kTrivial:
*value = stat.param.trivial_value;
if (use_max_value) {
// TODO(vrabaud) This could be optimized if, at encoding time, there is only
// one value and it hits max_value. Pessimization: 0.50%.
WP2_CHECK_OK(*value <= (int32_t)max_value, WP2_STATUS_BITSTREAM_ERROR);
}
break;
case Stat::Type::kRange: {
uint32_t range;
if (!use_max_value) {
range = stat.range;
} else if (stat.use_mapping) {
// Find the biggest range such that mapping[range] <= max_value.
range = std::upper_bound(stat.mappings, stat.mappings + stat.range,
max_value) -
stat.mappings;
assert(range == stat.range || stat.mappings[range] > max_value);
assert(stat.mappings[range - 1] <= max_value);
} else {
range = std::min(stat.range, (uint16_t)(max_value + 1));
}
*value = dec_->ReadRValue(range, "range");
if (cost != nullptr) *cost += std::log2(range);
if (stat.use_mapping) *value = stat.mappings[*value];
break;
}
case Stat::Type::kDict: {
uint32_t raw_value;
if (use_max_value) {
const uint32_t max_index = FindSymbol(stat, max_value);
raw_value = dec_->ReadSymbol(stat.extra.codes, stat.extra.log2_tab_size,
max_index, "dict");
if (cost != nullptr) {
const uint32_t mapping_size =
stat.extra.codes[(1 << stat.extra.log2_tab_size) - 1].symbol + 1;
const uint32_t scaled_freq =
ScaleFreq(stat.extra.codes, stat.extra.log2_tab_size, raw_value,
Unmap(stat.mappings, mapping_size, max_value));
*cost -= std::log2(scaled_freq) - stat.extra.log2_tab_size;
}
} else {
raw_value = dec_->ReadSymbol(&stat.extra.codes[0],
stat.extra.log2_tab_size, "dict");
if (cost != nullptr) {
const uint16_t freq = GetInfo(stat.extra.codes, raw_value).freq;
*cost -= std::log2(freq) - stat.extra.log2_tab_size;
}
}
assert(stat.use_mapping);
*value = stat.mappings[raw_value];
break;
}
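// Golomb: an ANS-coded prefix followed by raw extra bits (or a range-coded
// remainder when capped by 'max_value' or at the last interval).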
case Stat::Type::kGolomb: {
const uint32_t prefix_size = stat.param.golomb.prefix_size;
uint32_t prefix;
bool use_range;
uint32_t range = 0;
if (use_max_value) {
const Golomb golomb_max(max_value, prefix_size);
const uint32_t i_inf = FindSymbol(stat, golomb_max.prefix);
const uint32_t raw_prefix = dec_->ReadSymbol(
&stat.extra.codes[0], stat.extra.log2_tab_size, i_inf, "golomb");
prefix =
std::min(golomb_max.prefix, (uint32_t)stat.mappings[raw_prefix]);
const uint32_t max_prefix =
std::min(golomb_max.prefix,
(uint32_t)stat.mappings[stat.extra.codes[i_inf].symbol]);
// Use ranges if we are at the last interval.
use_range = (prefix == max_prefix);
if (use_range) {
// Get (the biggest value with a prefix of 'max_prefix') + 1.
range = std::min(
max_value + 1,
Golomb::Merge(max_prefix, prefix_size, 0) +
(1 << Golomb::NumExtraBits(max_prefix, prefix_size)));
}
if (cost != nullptr) {
const uint32_t scaled_freq =
ScaleFreq(stat.extra.codes, stat.extra.log2_tab_size, raw_prefix,
golomb_max.prefix);
*cost -= std::log2(scaled_freq) - stat.extra.log2_tab_size;
}
} else {
const uint32_t raw_prefix = dec_->ReadSymbol(
&stat.extra.codes[0], stat.extra.log2_tab_size, "golomb");
prefix = stat.mappings[raw_prefix];
// Use ranges if we are at the last interval.
use_range = (prefix + 1 == stat.param.golomb.range);
if (use_range) range = stat.range;
if (cost != nullptr) {
const uint16_t freq = GetInfo(stat.extra.codes, raw_prefix).freq;
*cost -= std::log2(freq) - stat.extra.log2_tab_size;
}
}
const uint32_t extra_bits = Golomb::NumExtraBits(prefix, prefix_size);
uint32_t extra_bits_value;
if (extra_bits == 0) {
extra_bits_value = 0;
} else {
if (use_range) {
assert(Golomb::Merge(prefix, prefix_size, 0) < range);
range -= Golomb::Merge(prefix, prefix_size, 0);
extra_bits_value = dec_->ReadRValue(range, "extra_bits_value");
if (cost != nullptr) {
*cost += std::log2(range);
}
} else {
extra_bits_value = dec_->ReadUValue(extra_bits, "extra_bits_value");
if (cost != nullptr) {
*cost += extra_bits;
}
}
}
*value = Golomb::Merge(prefix, prefix_size, extra_bits_value);
break;
}
case Stat::Type::kAdaptiveBit: {
if (use_max_value && max_value == 0) {
*value = 0;
} else {
*value =
dec_->ReadBit(a_bits_[stat.param.a_bit_index].Proba(), "a_bit");
if (cost != nullptr) {
*cost += a_bits_[stat.param.a_bit_index].GetCost(*value);
}
}
a_bits_[stat.param.a_bit_index].Update(*value);
break;
}
case Stat::Type::kAdaptiveSymbol: {
ANSAdaptiveSymbol& asym = a_symbols_[stat.param.a_symbol_index];
if (use_max_value) {
*value = dec_->ReadSymbol(asym, max_value, "a_symbol");
if (cost != nullptr) {
*cost += asym.GetCost(*value, max_value);
}
} else {
*value = dec_->ReadSymbol(asym, "a_symbol");
if (cost != nullptr) {
*cost += asym.GetCost(*value);
}
}
asym.Update(*value);
break;
}
default:
// Did you forget to call ReadHeader for this symbol?
assert(false);
*value = 0;
dec_->PopBitTracesCustomPrefix(label);
return WP2_STATUS_BITSTREAM_ERROR;
}
if (*value != 0 && stat.type != Stat::Type::kTrivial &&
symbols_info_.Min(cluster, sym) < 0) {
// TODO(vrabaud) implement other methods to store negative values.
if (dec_->ReadBool("is_negative")) *value = -(*value);
if (cost != nullptr) *cost += 1.;
}
if (use_max_value) assert((uint32_t)std::abs(*value) <= max_value);
dec_->PopBitTracesCustomPrefix(label);
return WP2_STATUS_OK;
}
WP2Status SymbolReader::ReadHeader(uint32_t max_nnz, uint32_t sym,
WP2_OPT_LABEL) {
for (uint32_t cluster = 0; cluster < symbols_info_.NumClusters(sym);
++cluster) {
WP2_CHECK_STATUS(ReadHeader(cluster, max_nnz, sym, label));
}
return dec_->GetStatus();
}
WP2Status SymbolReader::ReadHeader(uint32_t cluster, uint32_t max_nnz,
uint32_t sym, WP2_OPT_LABEL) {
if (symbols_info_.Method(sym) == SymbolsInfo::StorageMethod::kUnused) {
return WP2_STATUS_OK;
}
const uint32_t range = symbols_info_.Range(cluster, sym);
if (range < 2) {
assert(range > 0);
AddTrivial(cluster, sym, 0);
return WP2_STATUS_OK;
}
ANSDebugPrefix prefix(dec_, label);
switch (symbols_info_.Method(sym)) {
case SymbolsInfo::StorageMethod::kAuto:
// This main case is handled below.
break;
case SymbolsInfo::StorageMethod::kAdaptiveBit:
WP2_CHECK_STATUS(AddAdaptiveBit(
cluster, sym, symbols_info_.StartingProbaP0(cluster, sym),
symbols_info_.StartingProbaP1(cluster, sym)));
return dec_->GetStatus();
case SymbolsInfo::StorageMethod::kAdaptiveSym: {
WP2_CHECK_STATUS(AddAdaptiveSymbol(cluster, sym,
ANSAdaptiveSymbol::Method::kAOM,
kANSAProbaInvalidSpeed));
return dec_->GetStatus();
}
case SymbolsInfo::StorageMethod::kAdaptiveWithAutoSpeed: {
if (max_nnz <= 1) {
// Adaptation method/speed don't matter if we only use this symbol once.
return AddAdaptiveSymbol(cluster, sym,
ANSAdaptiveSymbol::Method::kConstant, 0);
}
const auto method = (ANSAdaptiveSymbol::Method)dec_->ReadRValue(
(uint32_t)ANSAdaptiveSymbol::Method::kNum, "adaptation_method");
uint32_t speed;
if (method == ANSAdaptiveSymbol::Method::kConstant) {
speed = kAdaptationSpeeds[dec_->ReadRValue(kNumAdaptationSpeeds,
"adaptation_speed")];
} else {
speed = kANSAProbaInvalidSpeed;
}
WP2_CHECK_STATUS(AddAdaptiveSymbol(cluster, sym, method, speed));
return dec_->GetStatus();
}
case SymbolsInfo::StorageMethod::kUnused:
break;
}
const uint32_t nnz_range = std::min(max_nnz, range);
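// 'scount' tells whether zero, one, or more distinct symbol values are used.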
const SymbolCount c = (SymbolCount)dec_->ReadRValue(
(nnz_range == 1) ? kSymbolCountLast - 1 : kSymbolCountLast, "scount");
// Deal with trivial cases.
if (c == kSymbolCountZero) {
AddTrivial(cluster, sym, 0);
return dec_->GetStatus();
}
if (c == kSymbolCountOne) {
const int32_t value =
dec_->ReadRange(symbols_info_.Min(cluster, sym),
symbols_info_.Max(cluster, sym), "symbol");
AddTrivial(cluster, sym, value);
return dec_->GetStatus();
}
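// Storage type: 0 for range, 1 for dictionary, 2 for Golomb.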
const uint32_t type = dec_->ReadRValue(3, "type");
if (type == 2) { // kGolomb
const uint32_t golomb_prefix_size = dec_->ReadRange(1, 2, "prefix_size_m1");
const Golomb golomb(range - 1, golomb_prefix_size);
const uint32_t range_golomb = golomb.prefix + 1;
const uint32_t nnz =
dec_->ReadRange(2, std::min(max_nnz, range_golomb), "size_m2");
WP2_CHECK_STATUS(
ReadHistogram(nnz, range_golomb, max_nnz, dec_, infos_));
WP2_CHECK_STATUS(AddGolomb(cluster, sym, &infos_, golomb_prefix_size));
} else if (type == 1) { // kDict
// A kDict cannot use more than ANS_MAX_SYMBOLS symbols.
const uint32_t dict_range = std::min(nnz_range, (uint32_t)ANS_MAX_SYMBOLS);
const uint32_t nnz = dec_->ReadRange(2, dict_range, "size_m2");
WP2_CHECK_STATUS(ReadHistogram(nnz, range, max_nnz, dec_, infos_));
WP2_CHECK_STATUS(AddDict(cluster, sym, &infos_));
} else { // kRange
const uint32_t nnz = dec_->ReadRange(1, nnz_range, "size_m1");
WP2_CHECK_STATUS(LoadMapping(dec_, nnz, range, mapping_));
// Store everything in the info.
WP2_CHECK_ALLOC_OK(infos_.resize(nnz));
for (uint32_t k = 0; k < infos_.size(); ++k) {
infos_[k].symbol = mapping_[k];
infos_[k].freq = 1;
}
AddRange(cluster, sym, &infos_, range - nnz);
}
return dec_->GetStatus();
}
void SymbolReader::GetPotentialUsage(uint32_t cluster, uint32_t sym,
bool is_maybe_used[],
uint32_t size) const {
const Stat& stat = *GetStats(cluster, sym);
switch (stat.type) {
case (Stat::Type::kTrivial):
// Nothing is used but the one value.
std::fill(is_maybe_used, is_maybe_used + size, false);
is_maybe_used[stat.param.trivial_value] = true;
break;
case (Stat::Type::kRange):
// We have no idea of what is used or not.
std::fill(is_maybe_used, is_maybe_used + size, true);
break;
case (Stat::Type::kDict):
// Go over the stats to figure out what is used or not.
if (stat.use_mapping) {
std::fill(is_maybe_used, is_maybe_used + size, false);
}
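// Stepping by 'freq' jumps to the start of the next interval, visiting each
// symbol with a non-zero frequency exactly once.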
for (uint32_t k = 0; k < (1u << stat.extra.log2_tab_size);
k += stat.extra.codes[k].freq) {
is_maybe_used[stat.use_mapping
? stat.mappings[stat.extra.codes[k].symbol]
: stat.extra.codes[k].symbol] = true;
}
break;
case Stat::Type::kGolomb: {
// Go over the stats to figure out what is used or not.
if (stat.use_mapping) {
std::fill(is_maybe_used, is_maybe_used + size, false);
}
const uint32_t prefix_size = stat.param.golomb.prefix_size;
for (uint32_t k = 0; k < (1u << stat.extra.log2_tab_size);
k += stat.extra.codes[k].freq) {
const uint32_t m = stat.use_mapping
? stat.mappings[stat.extra.codes[k].symbol]
: stat.extra.codes[k].symbol;
const uint32_t extra_bits_num = Golomb::NumExtraBits(m, prefix_size);
const uint32_t m1 = Golomb::Merge(m, prefix_size, 0);
const uint32_t m2 =
Golomb::Merge(m, prefix_size, (1 << extra_bits_num) - 1);
std::fill(is_maybe_used + m1, is_maybe_used + std::min(m2 + 1, size),
true);
}
break;
}
default:
assert(false);
break;
}
}
} // namespace WP2