| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ----------------------------------------------------------------------------- |
| // |
| // Asymmetric Numeral System helper functions. |
| // |
| // Author: Skal (pascal.massimino@gmail.com) |
| |
| #include "src/utils/ans_utils.h" |
| |
| #include <algorithm> |
| #include <array> |
| #include <cassert> |
| #include <cstddef> |
| #include <cstdint> |
| #include <cstring> |
| #include <numeric> |
| #include <optional> |
| |
| #include "src/dsp/math.h" |
| #include "src/utils/ans.h" |
| #include "src/utils/ans_enc.h" |
| #include "src/wp2/base.h" |
| |
| namespace WP2 { |
| |
| //------------------------------------------------------------------------------ |
| // ANS quantizations. |
| |
| void ANSCountsQuantizeHuffman(uint32_t size, const uint32_t* const counts, |
| uint32_t* const out, uint32_t max_bits, |
| float* const cost) { |
| const uint32_t current_max_value = *std::max_element(counts, counts + size); |
| const uint32_t max_value = std::min(1u << (max_bits - 1), current_max_value); |
| |
| uint32_t sum = 0, sum_q = 0; |
| const float norm = 1.f * max_value / current_max_value; |
| *cost = 0.f; |
| for (size_t c = 0; c < size; ++c) { |
| if (counts[c] > 0) { |
| const float scaled = std::max(counts[c] * norm, 1.f); |
| uint32_t num_bits = WP2Log2Floor((uint32_t)scaled); |
| uint32_t huffmanized = (1u << num_bits); |
| if ((huffmanized << 1) - scaled < scaled - huffmanized) { |
| ++num_bits; |
| huffmanized <<= 1; |
| } |
| out[c] = huffmanized; |
| *cost -= counts[c] * num_bits; |
| sum += counts[c]; |
| sum_q += huffmanized; |
| } else { |
| out[c] = 0; |
| } |
| } |
| |
| *cost += sum * WP2Log2(sum_q); |
| } |
| |
| bool ANSCountsQuantize(bool do_expand, uint32_t max_freq, uint32_t size, |
| uint32_t* const counts, float* const cost) { |
| return ANSCountsQuantize(do_expand, max_freq, size, counts, counts, cost); |
| } |
| |
| bool ANSCountsQuantize(bool do_expand, uint32_t max_freq, uint32_t size, |
| const uint32_t* const counts, uint32_t* const out, |
| float* const cost) { |
| assert(max_freq > 0); |
| if (size == 0) return false; |
| const uint32_t max_value = *std::max_element(counts, counts + size); |
| if (max_value == 0 || max_freq == max_value || |
| (!do_expand && max_value <= max_freq)) { |
| if (out != counts) memcpy(out, counts, size * sizeof(uint32_t)); |
| if (cost != nullptr) *cost = ANSCountsCost(counts, size); |
| return false; |
| } |
| uint32_t sum = 0, sum_q = 0; |
| if (cost != nullptr) *cost = 0; |
| const float norm = 1. * max_freq / max_value; |
| for (size_t c = 0; c < size; ++c) { |
| if (counts[c] > 0) { |
| const int new_count = (int)(counts[c] * norm + .5); |
| out[c] = (new_count < 1) ? 1 : new_count; |
| if (cost != nullptr) { |
| *cost -= counts[c] * WP2Log2(out[c]); |
| sum += counts[c]; |
| sum_q += out[c]; |
| } |
| } else { |
| out[c] = 0; |
| } |
| } |
| if (cost != nullptr) { |
| const float cost_sum = 0.f - sum * WP2Log2(sum_q); |
| if (*cost == cost_sum) { |
| // Remove rounding errors when only one count is non-zero. |
| *cost = 0.f; |
| } else { |
| *cost -= cost_sum; |
| } |
| } |
| return true; |
| } |
| |
| // The bit cost of an element "i" is: |
| // log2(1/proba[i]) = log2(sum(counts)/counts[i]) |
| float ANSCountsCost(const uint32_t* counts, uint32_t size) { |
| return ShannonEntropy(counts, size); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // ANS vector storage. |
| |
| // Compute the bit cost when merging two stats. |
| static inline uint32_t ComputeCost(const OptimizeArrayStorageStat* const s1, |
| const OptimizeArrayStorageStat* const s2) { |
| const int bit_diff = (s1->n_bits - s2->n_bits); |
| int total_diff = bit_diff * ((bit_diff > 0) ? s2->count : -s1->count); |
| // If there was only one value, by merging we lose the optimization that we |
| // don't need to store the high order bit. |
| // TODO(maryla): This is approximate. In the case where n_bits == n_bits_max |
| // (when values are stored with an RValue instead of a UValue), we potentially |
| // win a lot more than one bit. |
| if (s1->count == 1) ++total_diff; |
| if (s2->count == 1) ++total_diff; |
| return total_diff; |
| } |
| |
| // 'overhead' is the size in bits of: size in bits to store the number of |
| // elements and their size in bits. |
| // The list of stats is simplified to represent the same number of elements but |
| // with a smaller overall bit cost. |
| void OptimizeArrayStorage(OptimizeArrayStorageStat* const stats, |
| size_t* const size_in, float overhead) { |
| // TODO(vrabaud) include an effort parameter to choose between the two methods |
| // or at least, choose the slower one if size is small (e.g. < |
| // 5). |
| // The first method gives a speedup of an order of magnitude for a compression |
| // hit of less than 0.1% |
| size_t size = *size_in; |
| #if OPTIMIZE_ARRAY_STORAGE |
| while (size > 1) { |
| auto s2 = stats + 1, s_end = stats + size; |
| // Find the first possible merge, if any. |
| while (s2 < s_end && ComputeCost(s2 - 1, s2) >= overhead) ++s2; |
| |
| auto s1 = s2 - 1; |
| for (; s2 < s_end; ++s2) { |
| if (ComputeCost(s1, s2) < overhead) { |
| // Merge. |
| s1->count = s2->count + s1->count; |
| s1->n_bits = std::max(s2->n_bits, s1->n_bits); |
| } else { |
| ++s1; |
| *s1 = *s2; |
| } |
| } |
| const size_t size_new = s1 - stats + 1; |
| if (size_new == size) break; |
| size = size_new; |
| } |
| #else |
| // TODO(vrabaud) overhead actually decreases as the number of elements becomes |
| // less and less: take that into account. |
| // Fill the costs once and for all. |
| for (size_t i = 1; i < size; ++i) { |
| stats[i].cost_ = ComputeCost(&stats[i - 1], &stats[i]); |
| } |
| |
| size_t start = 1, end = 1; |
| while (start < size) { |
| // Find the first element that could be optimized. |
| while (start < size && stats[start].cost_ >= overhead) { |
| ++start; |
| } |
| if (start == size) break; |
| |
| // Find the first following element that cannot be optimized (or the end). |
| if (end <= start) { |
| end = start + 1; |
| while (end < size && stats[end].cost_ < overhead) { |
| ++end; |
| } |
| } |
| |
| do { |
| size_t i = start, i_max = i; |
| // Find the minimal cost. |
| uint32_t max_cost = stats[i].cost_; |
| for (++i; i < end; ++i) { |
| if (stats[i].cost_ < max_cost) { |
| max_cost = stats[i].cost_; |
| i_max = i; |
| } |
| } |
| if (max_cost < overhead) break; |
| // Only continue if we were able to have a good enough cost. |
| auto max_s2 = stats + i_max, max_s1 = max_s2 - 1; |
| max_s1->count_ = max_s1->count_ + max_s2->count_; |
| max_s1->n_bits_ = std::max(max_s1->n_bits_, max_s2->n_bits_); |
| if (max_s1 != stats) max_s1->cost_ = ComputeCost(max_s1 - 1, max_s1); |
| // Remove max_s2. |
| memmove(stats + i_max, stats + i_max + 1, |
| (size - i_max - 1) * sizeof(OptimizeArrayStorageStat)); |
| --size; |
| --end; |
| if (i_max != size) { |
| stats[i_max].cost_ = ComputeCost(stats + i_max - 1, stats + i_max); |
| } |
| } while (start < end); |
| } |
| #endif |
| *size_in = size; |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // ANSEncCounter |
| |
| float ANSEncCounter::GetCost() const { return cost_ + symbol_cost_; } |
| |
| float ANSEncCounter::GetCost(const ANSDictionaries& dicts) const { |
| float cost = cost_; |
| for (const auto& d : dicts) { |
| if (d != nullptr) cost += d->Cost(); |
| } |
| return cost; |
| } |
| |
| uint32_t ANSEncCounter::PutBit(uint32_t bit, uint32_t num_zeros, |
| uint32_t num_total, WP2_OPT_LABEL) { |
| cost_ += |
| WP2Log2(num_total) - WP2Log2(bit ? num_total - num_zeros : num_zeros); |
| return bit; |
| } |
| |
| uint32_t ANSEncCounter::PutABit(uint32_t bit, ANSBinSymbol* const stats, |
| WP2_OPT_LABEL) { |
| ANSEncCounter::PutBit(bit, stats->NumZeros(), stats->NumTotal(), |
| label); // will update cost_ |
| stats->Update(bit); |
| return bit; |
| } |
| |
| uint32_t ANSEncCounter::PutSymbol(uint32_t symbol, const ANSDictionary& dict, |
| WP2_OPT_LABEL) { |
| symbol_cost_ += dict.SymbolCost(symbol); |
| return symbol; |
| } |
| |
| uint32_t ANSEncCounter::PutSymbol(uint32_t symbol, |
| const ANSAdaptiveSymbol& asym, |
| WP2_OPT_LABEL) { |
| cost_ += asym.GetCost(symbol); |
| return symbol; |
| } |
| |
| uint32_t ANSEncCounter::PutSymbol(uint32_t symbol, uint32_t max_symbol, |
| const ANSDictionary& dict, WP2_OPT_LABEL) { |
| symbol_cost_ += dict.SymbolCost(symbol, max_symbol); |
| return symbol; |
| } |
| |
| uint32_t ANSEncCounter::PutSymbol(uint32_t symbol, uint32_t max_symbol, |
| const ANSAdaptiveSymbol& asym, |
| WP2_OPT_LABEL) { |
| cost_ += asym.GetCost(symbol, max_symbol); |
| return symbol; |
| } |
| |
| uint32_t ANSEncCounter::PutUValue(uint32_t value, uint32_t bits, |
| WP2_OPT_LABEL) { |
| cost_ += bits; |
| return value; |
| } |
| |
| uint32_t ANSEncCounter::PutRValue(uint32_t value, uint32_t range, |
| WP2_OPT_LABEL) { |
| cost_ += WP2Log2(range); |
| return value; |
| } |
| |
| WP2Status ANSEncCounter::Append(const ANSEncCounter& enc) { |
| cost_ += enc.cost_; |
| return WP2_STATUS_OK; |
| } |
| |
| void ANSEncCounter::Reset() { |
| cost_ = 0; |
| symbol_cost_ = 0; |
| } |
| |
| //------------------------------------------------------------------------------ |
| // Mapping utils. |
| |
// Methods available to code a mapping. Their numeric value is written as the
// "method" symbol (see StoreMapping() and LoadMapping()), so it must not
// change.
enum class MappingMethod {
  kRange = 0,
  kStoreVector = 1,
  kAdaptativeBit = 2,
  kBit = 3,
};
// Number of MappingMethod values, i.e. the range of the "method" symbol.
static constexpr uint32_t kNumMapping =
    static_cast<uint32_t>(MappingMethod::kBit) + 1u;
| |
// Reads a mapping written by StoreMapping(): 'size' sorted indices out of
// [0, range), written to mapping[0, size).
// The bitstream contains an "is_series" bool (absent when size == range). If
// it is false, a "method" RValue in [0, kNumMapping) follows, then the
// method-specific data. Each case below must mirror the matching case of
// StoreMappingHelper().
// Always returns WP2_STATUS_OK; decoding errors are presumably tracked by
// 'dec' itself (TODO confirm).
WP2Status LoadMapping(ANSDec* const dec, uint32_t size, uint32_t range,
                      uint16_t* const mapping) {
  assert(size <= range);
  assert(mapping != nullptr);
  if (size == 0) return WP2_STATUS_OK;
  ANSDebugPrefix prefix(dec, "mapping");
  // TODO(vrabaud) Have it be a ReadBit with learned probability.
  // The mapping is the series 0, 1, ..., size - 1 either when every index is
  // used (nothing is read) or when the "is_series" flag is set.
  if (size == range || dec->ReadBool("is_series")) {
    std::iota(mapping, mapping + size, 0);
    return WP2_STATUS_OK;
  }

  const auto method = (MappingMethod)dec->ReadRValue(kNumMapping, "method");
  switch (method) {
    case MappingMethod::kRange: {
      // Read the runs, alternating between the beginning and end ones.
      // The bounds leave room for the 'num_left' indices still to be read:
      // mapping[i] is in [mapping[i - 1] + 1, mapping[j] - num_left] and
      // mapping[j] in [mapping[i] + num_left, mapping[j + 1] - 1] (with 0 and
      // range - 1 as outer bounds for the first pair).
      uint32_t i = 0, j = size;
      for (uint32_t num_left = size; num_left > 0; ++i) {
        if (i == 0) {
          mapping[i] = dec->ReadRange(0, range - num_left, "run_beg");
        } else {
          mapping[i] = dec->ReadRange(mapping[i - 1] + 1, mapping[j] - num_left,
                                      "run_beg");
        }
        --num_left;
        if (num_left > 0) {
          --j;
          mapping[j] = dec->ReadRange(mapping[i] + num_left,
                                      (i == 0) ? range - 1 : mapping[j + 1] - 1,
                                      "run_end");
          --num_left;
          // early out to avoid no-op calls to ReadRange()
          if (mapping[i] + num_left + 1 == mapping[j]) break;
        }
      }
      // Anything left between i and j is a run of consecutive indices.
      while (++i < j) mapping[i] = mapping[i - 1] + 1;
      break;
    }
    case MappingMethod::kAdaptativeBit: {
      // One bit per index j tells whether j is used. Its probability comes from
      // the number of unused indices left ('num_zeros') among the 'num_left'
      // remaining ones. Once no unused index is left, nothing more is read.
      uint32_t num_zeros = range - size;
      for (uint32_t i = 0, j = 0; i < size; ++j) {
        const uint32_t num_left = range - j;
        if (num_zeros == 0 || dec->ReadBit(num_zeros, num_left, "is_used")) {
          mapping[i] = j;
          ++i;
        } else {
          --num_zeros;
        }
      }
      break;
    }
    case MappingMethod::kBit: {
      // Same as kAdaptativeBit, but with plain bools.
      uint32_t num_zeros = range - size;
      for (uint32_t i = 0, j = 0; i < size; ++j) {
        if (num_zeros == 0 || dec->ReadBool("is_used")) {
          mapping[i] = j;
          ++i;
        } else {
          --num_zeros;
        }
      }
      break;
    }
    case MappingMethod::kStoreVector: {
      // The gaps between consecutive indices are read in sets of 'n' values
      // sharing the same bit depth 'n_bits', as written by
      // StoreVectorForMapping(). The gaps are mapping[0], then
      // mapping[k] - mapping[k - 1] - 1.
      size_t k = 0;
      uint32_t n_bits_prev = 0;

      // 'range_left' is the range of the remaining runs, hence +1.
      uint32_t range_left = range - size + 1;
      uint32_t size_left = size;
      while (size_left > 0 && range_left > 1) {
        const uint32_t n_bits_max = FindLastSet(range_left - 1);
        // Figure out the number of bits.
        uint32_t n_bits;
        // The number of bits is in [0:n_bits_max] for the first set or when the
        // previous value was out of the current range.
        if (k == 0 || n_bits_prev > n_bits_max) {
          n_bits = dec->ReadRValue(n_bits_max + 1, "n_bits");
        } else {
          // If we know n_bits_prev is in [0:n_bits_max], we can save one bit
          // as we know the value has to be different from before.
          n_bits = dec->ReadRValue(n_bits_max, "n_bits");
          if (n_bits >= n_bits_prev) ++n_bits;
        }
        n_bits_prev = n_bits;

        // Figure out the number of elements with that bit depth.
        const uint32_t n = dec->ReadRange(1, size_left, "count");
        if (n_bits == 0) {
          // For a depth of 0 bits, we know it is a 0.
          for (size_t i = 0; i < n; ++i) {
            mapping[k] = (k == 0) ? 0 : mapping[k - 1] + 1;
            ++k;
          }
        } else if (n == 1) {
          // If we only have one number, we didn't store the high order
          // bit as it's always 1, otherwise it would fit in (n_bits_ - 1).
          uint32_t delta;
          if (n_bits == 1) {
            // And if that number is only on one bit to begin with, we didn't
            // store anything.
            delta = 1u;
          } else {
            delta = 1u << (n_bits - 1);
            assert(range_left >= delta);
            // After this value, the range will be at most range_left - delta.
            // If it is smaller than the number of bits we use, we can just use
            // a range.
            if (range_left - delta < delta) {
              delta |= dec->ReadRValue(range_left - delta, "value");
            } else {
              delta |= dec->ReadUValue(n_bits - 1, "value");
            }
          }
          mapping[k] = (k == 0) ? delta : mapping[k - 1] + delta + 1;
          ++k;
          assert(range_left >= delta);
          range_left -= delta;
        } else {
          for (size_t i = 0; i < n; ++i) {
            uint32_t value;
            // Use an RValue when the remaining range is smaller than what
            // 'n_bits' bits can represent.
            if (range_left < (1u << n_bits)) {
              value = dec->ReadRValue(range_left, "value");
            } else {
              value = dec->ReadUValue(n_bits, "value");
            }
            mapping[k] = (k == 0) ? value : mapping[k - 1] + value + 1;
            ++k;
            assert(range_left >= value);
            range_left -= value;
            // Stop here if everything else is in [0, 1) hence is 0.
            if (range_left == 1) break;
          }
        }
        assert(size_left >= n);
        size_left -= n;
      }

      // If we have not filled 'mapping' yet, it means we only have 0's left
      // for the runs.
      assert(k > 0);
      std::iota(&mapping[k], &mapping[size], mapping[k - 1] + 1);
      break;
    }
  }
  return WP2_STATUS_OK;
}
| |
// Returns the effective reduced range of mapping values that are not trivial.
// This is the number of "is_used" bools the kBit method writes. Positions are
// scanned from 0, and the scan ends either right after the last used index or
// once all 'range - size' unused positions have been seen, since the remaining
// ones are then implied.
static uint32_t GetEffectiveRange(const uint16_t* const mapping, uint32_t size,
                                  uint32_t range) {
  const uint32_t num_unused = range - size;
  uint32_t last_end = 0;
  for (uint32_t idx = 0; idx < size; ++idx) {
    const uint32_t limit = num_unused + idx;
    last_end = mapping[idx] + 1u;
    if (last_end > limit) return limit;
  }
  return last_end;
}
| |
// Helper function that is a copy/paste of StoreVector with a few tweaked
// values.
// Writes the gaps between consecutive indices of 'mapping' (mapping[0], then
// mapping[ind] - mapping[ind - 1] - 1) grouped by the 'stats_size' sets in
// 'stats'. Each set writes its bit depth "n_bits", its number of values
// "count", then the values themselves.
// 'range_left' is one more than the sum of the gaps that can still be written.
// It is used to pick RValues when they are cheaper than UValues.
// If 'cost_max' is set, writing stops once enc->GetCost() reaches it.
// The layout must mirror the kStoreVector case of LoadMapping().
static void StoreVectorForMapping(const uint16_t* const mapping, size_t size,
                                  uint32_t range,
                                  const OptimizeArrayStorageStat* const stats,
                                  uint32_t stats_size,
                                  std::optional<float> cost_max,
                                  ANSEncBase* const enc) {
  assert(size <= range);
  uint32_t n_bits_prev = 0;
  size_t size_left = size;
  uint32_t range_left = range - size + 1;
  for (uint32_t ind = 0, i = 0; i < stats_size; ++i) {
    assert(range_left > 1);
    assert(ind < size);
    // We need n_bits_max to store what is left.
    const uint32_t n_bits_max = FindLastSet(range_left - 1);
    const OptimizeArrayStorageStat& s = stats[i];
    // The number of bits is in [0:n_bits_max] for the first set or when the
    // previous value was out of the current range.
    if (ind == 0 || n_bits_prev > n_bits_max) {
      enc->PutRValue(s.n_bits, n_bits_max + 1, "n_bits");
    } else {
      // If we know n_bits_prev is in [0:n_bits_max], we can save one bit as we
      // know the value has to be different from before.
      // If we store n_bits0 for a set, the following n_bits1 will be stored
      // as:
      //   n_bits1 if n_bits1 < n_bits0,
      //   n_bits1 - 1 otherwise.
      if (s.n_bits < n_bits_prev) {
        enc->PutRValue(s.n_bits, n_bits_max, "n_bits");
      } else {
        enc->PutRValue(s.n_bits - 1, n_bits_max, "n_bits");
      }
    }
    n_bits_prev = s.n_bits;
    // Store the number of values.
    enc->PutRange(s.count, 1, size_left, "count");

    if (s.n_bits == 0) {
      // All gaps of this set are 0: nothing else to write.
      ind += s.count;
    } else if (s.count == 1) {
      // If we only have one number, we don't need to store the high order
      // bit as it's always 1, otherwise it would fit in (n_bits_ - 1).
      // And if that number is only on one bit to begin with, we don't need
      // to store anything.
      const uint16_t val =
          (ind == 0) ? mapping[ind] : mapping[ind] - mapping[ind - 1] - 1;
      if (s.n_bits > 1) {
        const uint32_t high_order_bit = 1u << (s.n_bits - 1);
        const uint32_t value_to_code = val ^ high_order_bit;
        // After this value, the range will be at
        // most range_left - high_order_bit. If it is smaller than the
        // number of bits we use, we can just use a range.
        if ((uint32_t)(range_left - high_order_bit) < high_order_bit) {
          enc->PutRValue(value_to_code, range_left - high_order_bit, "value");
        } else {
          enc->PutUValue(value_to_code, s.n_bits - 1, "value");
        }
      }
      range_left -= val;
      ++ind;
    } else {
      // Store the values.
      for (size_t j = 0; j < s.count; ++j) {
        const uint16_t val =
            (ind == 0) ? mapping[ind] : mapping[ind] - mapping[ind - 1] - 1;
        // Use an RValue when the remaining range is smaller than what
        // 's.n_bits' bits can represent.
        if (range_left < (1u << s.n_bits)) {
          enc->PutRValue(val, range_left, "value");
        } else {
          enc->PutUValue(val, s.n_bits, "value");
        }
        range_left -= val;
        ++ind;
      }
    }
    // Early exit if the cost is too high.
    if (cost_max && enc->GetCost() >= *cost_max) break;
    size_left -= s.count;
  }
}
| |
// Helper function storing the mapping with one of the different methods.
// If cost_max is set, it is assumed 'enc' is an ANSEncCounter and there will be
// an early exit if its cost is greater than cost_max.
// 'stats'/'stats_size' are only used by kStoreVector. Each case must mirror the
// matching case of LoadMapping().
static void StoreMappingHelper(const uint16_t* const mapping, size_t size,
                               uint32_t range,
                               const OptimizeArrayStorageStat* const stats,
                               size_t stats_size, MappingMethod method,
                               std::optional<float> cost_max,
                               ANSEncBase* const enc) {
  switch (method) {
    case MappingMethod::kRange: {
      // Store everything as a range by alternating between the beginning and
      // end runs to get 'num_left' as small as possible.
      // This method is usually good for few values in a big range.
      for (uint32_t i = 0, j = size, num_left = size; num_left > 0; ++i) {
        // Deal with the beginning runs.
        if (i == 0) {
          enc->PutRange(mapping[i], 0, range - num_left, "run_beg");
        } else {
          enc->PutRange(mapping[i], mapping[i - 1] + 1, mapping[j] - num_left,
                        "run_beg");
        }
        --num_left;
        // Deal with the end runs.
        if (num_left > 0) {
          --j;
          enc->PutRange(mapping[j], mapping[i] + num_left,
                        (i == 0) ? range - 1 : mapping[j + 1] - 1, "run_end");
          --num_left;
          // early-out to avoid empty calls to PutRange()
          // (the indices in between are then forced to be consecutive).
          if (mapping[i] + num_left + 1 == mapping[j]) break;
        }
        // Early exit if the cost is too high.
        if (cost_max && enc->GetCost() >= *cost_max) break;
      }
      break;
    }
    case MappingMethod::kStoreVector: {
      // Store everything as a list of tuples like in StoreVector.
      // This method is usually good for many values in a big range.
      StoreVectorForMapping(mapping, size, range, stats, stats_size, cost_max,
                            enc);
      break;
    }
    case MappingMethod::kAdaptativeBit: {
      // Store everything as a succession of probability bits indicating whether
      // the index is used.
      // This method is usually good for many values in a small range.
      uint32_t num_zeros = range - size;
      // Once there are no more zeros, it will cost nothing to store the bit.
      for (uint32_t i = 0, j = 0; i < size; ++j) {
        const uint32_t num_left = range - j;
        if (enc->PutBit(j == mapping[i], num_zeros, num_left, "is_used")) {
          ++i;
        } else {
          // early out to avoid calling PutBit() with proba 0
          if (--num_zeros == 0) break;
        }
        // Early exit if the cost is too high.
        if (cost_max && enc->GetCost() >= *cost_max) break;
      }
      break;
    }
    case MappingMethod::kBit: {
      // Store everything as a succession of bits indicating whether the index
      // is used.
      // This method is usually good for few values in a small range, and
      // located at the beginning.
      uint32_t num_zeros = range - size;
      // We keep going until we cannot deduce what is left (i.e. as long as
      // there is a 1 or 0).
      for (uint32_t i = 0, j = 0; i < size && num_zeros > 0; ++j) {
        if (enc->PutBool(j == mapping[i], "is_used")) {
          ++i;
        } else {
          --num_zeros;
        }
        // Early exit if the cost is too high.
        if (cost_max && enc->GetCost() >= *cost_max) break;
      }
      break;
    }
  }
}
| |
| // stats[] stores number of bits, and numbers of consecutive values with |
| // that bit depth. |
| static uint32_t CollectOptimalStats(const uint16_t* const mapping, size_t size, |
| uint32_t range, |
| OptimizeArrayStorageStat* const stats) { |
| size_t stats_size = 0; |
| uint16_t val_max = 0, val_min = 0; |
| uint32_t range_left = range - size + 1; |
| uint16_t prev_val = 0; |
| for (size_t i = 0; i < size; ++i) { |
| assert(mapping[i] < range); |
| const uint16_t val = mapping[i] - prev_val; |
| prev_val = mapping[i] + 1; |
| // Check if the highest set bit is the same (faster than calling |
| // FindLastSet). |
| if (val < val_max && val >= val_min) { |
| ++stats[stats_size - 1].count; |
| } else { |
| stats[stats_size].count = 1; |
| stats[stats_size].n_bits = FindLastSet(val); |
| val_max = (1u << stats[stats_size].n_bits); |
| val_min = (val_max >> 1); |
| ++stats_size; |
| } |
| // Stop here as everything else is in [0, 1) hence is 0. |
| range_left -= val; |
| if (range_left == 1) break; |
| } |
| // Merge the pairs to give an optimal cost. |
| const uint8_t n_bits_max = FindLastSet(range_left - 1); |
| assert(n_bits_max <= kANSMaxRangeBits); |
| OptimizeArrayStorage(stats, &stats_size, |
| WP2Log2Fast(n_bits_max) + WP2Log2(size + 1)); |
| return stats_size; |
| } |
| |
| float StoreMapping(const uint16_t* const mapping, size_t size, uint32_t range, |
| int effort, OptimizeArrayStorageStat* const stats, |
| ANSEncBase* const enc) { |
| if (size == 0 || size == range) return 0.f; |
| // Check whether all is consecutive, starting at 0. |
| const bool is_series = (mapping[0] == 0 && mapping[size - 1] == size - 1); |
| ANSDebugPrefix prefix(enc, "mapping"); |
| if (enc != nullptr) { |
| enc->PutBool(is_series, "is_series"); |
| } |
| float cost_tot = 1.f; |
| if (is_series) return cost_tot; |
| |
| // Find the optimal stats[] sequence for the kStoreVector case |
| const size_t stats_size = CollectOptimalStats(mapping, size, range, stats); |
| |
| // Find the most cost-efficient mapping method. Since kBit cost is easy to |
| // compute, use it as initial value. |
| float cost_min = GetEffectiveRange(mapping, size, range); |
| MappingMethod best_method = MappingMethod::kBit; |
| constexpr std::array<MappingMethod, 3> methods = { |
| MappingMethod::kRange, MappingMethod::kStoreVector, |
| MappingMethod::kAdaptativeBit}; |
| const uint32_t num_extra_methods = effort == 0 ? 0 |
| : effort < 3 ? 1 |
| : effort < 6 ? 2 |
| : 3; |
| for (uint32_t i = 0; i < num_extra_methods; ++i) { |
| const MappingMethod method = methods[i]; |
| ANSEncCounter counter; |
| StoreMappingHelper(mapping, size, range, stats, stats_size, method, |
| /*cost_max=*/cost_min, &counter); |
| const float cost = counter.GetCost(); |
| if (cost < cost_min) { |
| cost_min = cost; |
| best_method = method; |
| } |
| } |
| cost_tot += cost_min; |
| |
| if (enc != nullptr) { |
| // Finalize the bitstream. |
| enc->PutRValue((uint32_t)best_method, kNumMapping, "method"); |
| StoreMappingHelper(mapping, size, range, stats, stats_size, best_method, |
| /*cost_max=*/std::nullopt, enc); |
| } |
| cost_tot += WP2Log2(kNumMapping); |
| |
| return cost_tot; |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| uint32_t PutLargeRange(uint32_t v, uint32_t min, uint32_t max, |
| ANSEncBase* const enc, WP2_OPT_LABEL) { |
| assert(v >= min); |
| assert(v <= max); |
| const uint32_t range = max - min + 1; |
| if (range <= kANSMaxRange) return enc->PutRange(v, min, max, label); |
| |
| v = v - min; |
| const uint32_t num_intervals = DivCeil(range, kANSMaxRange); |
| const uint32_t interval = v / kANSMaxRange; |
| if (range % kANSMaxRange == 0) { |
| enc->PutRValue(v % kANSMaxRange, kANSMaxRange, label); |
| enc->PutRValue(v / kANSMaxRange, num_intervals, label); |
| } else { |
| const bool is_last_interval = (interval == (num_intervals - 1)); |
| enc->PutBit(is_last_interval, range - (range % kANSMaxRange), range, label); |
| if (is_last_interval) { |
| enc->PutRValue(v % kANSMaxRange, range % kANSMaxRange, label); |
| } else { |
| enc->PutRValue(v % kANSMaxRange, kANSMaxRange, label); |
| enc->PutRValue(v / kANSMaxRange, num_intervals - 1, label); |
| } |
| } |
| return min + v; |
| } |
| |
| uint32_t ReadLargeRange(uint32_t min, uint32_t max, ANSDec* const dec, |
| WP2_OPT_LABEL) { |
| const uint32_t range = max - min + 1; |
| if (range <= kANSMaxRange) return dec->ReadRange(min, max, label); |
| |
| const uint32_t num_intervals = DivCeil(range, kANSMaxRange); |
| if (range % kANSMaxRange == 0) { |
| uint32_t v = dec->ReadRValue(kANSMaxRange, label); |
| v += kANSMaxRange * dec->ReadRValue(num_intervals, label); |
| return min + v; |
| } else { |
| const bool is_last_interval = |
| dec->ReadBit(range - (range % kANSMaxRange), range, label); |
| if (is_last_interval) { |
| return min + dec->ReadRValue(range % kANSMaxRange, label) + |
| kANSMaxRange * (num_intervals - 1); |
| } else { |
| uint32_t v = dec->ReadRValue(kANSMaxRange, label); |
| v += kANSMaxRange * dec->ReadRValue(num_intervals - 1, label); |
| return min + v; |
| } |
| } |
| } |
| |
| } // namespace WP2 |