| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ----------------------------------------------------------------------------- |
| |
// Tests for the ANS coding elements (bits, symbols, uniform and range values).
| |
| #include <algorithm> |
| #include <array> |
| #include <cmath> |
| #include <cstdint> |
| #include <cstdio> |
| #include <cstdlib> |
| #include <cstring> |
| #include <random> |
| #include <vector> |
| |
| #include "examples/example_utils.h" |
| #include "include/helpers.h" |
| #include "src/dsp/dsp.h" |
| #include "src/utils/ans.h" |
| #include "src/utils/ans_utils.h" |
| #include "src/utils/data_source.h" |
| #include "src/utils/quantizer.h" |
| #include "src/utils/random.h" |
| |
| namespace WP2 { |
| namespace { |
| |
// Default seed for the random generators used by the tests below.
constexpr uint32_t kBaseSeed = 463634;
// Scales the number of generated tokens/symbols: the higher, the longer the
// tests take.
constexpr uint32_t kTestFactor = 10;
// When true, the tests print additional debug info to stdout.
constexpr bool kVerbose = false;
| |
| //------------------------------------------------------------------------------ |
| |
| // Draws a random bit and have its probability be 0 or kProbaMax N% of the time. |
| constexpr uint32_t kProbaMax = PROBA_MAX / 2; |
| inline void DrawRandomBit(uint32_t N, UniformIntDistribution* const random, |
| uint32_t* const p, uint32_t* const b) { |
| if ((N > 0) && (random->Get(0u, 100u) <= N)) { |
| // Force to reach the bounds of probabilities. |
| *p = kProbaMax * random->Get(0u, 1u); |
| } else { |
| *p = random->Get(0u, kProbaMax); |
| } |
| if (*p == 0) { |
| *b = 1; |
| } else if (*p == kProbaMax) { |
| *b = 0; |
| } else { |
| *b = random->Get(0u, 1u); |
| } |
| } |
| |
// One coded element: its kind plus up to three integer payloads whose meaning
// depends on 'type' (GenerateRandomEnc() documents the layout per type).
struct Token {
  enum Type { Bit, Bool, ASymbol, Symbol, Range, MinMax, Uniform, Signed, ABit,
              None };

  explicit Token(Type token_type = None, uint32_t token_first = 0,
                 uint32_t token_second = 0, uint32_t token_third = 0)
      : type{token_type},
        first{static_cast<int32_t>(token_first)},
        second{static_cast<int32_t>(token_second)},
        third{static_cast<int32_t>(token_third)} {}

  Type type;
  int32_t first = 0;
  int32_t second = 0;
  int32_t third = 0;
};
| |
| // Generate some random signal and append it to the ANS encoder. |
| // It creates 'n_elements' of type chosen randomly from the ones in |
| // 'used_types'. |
| // 'tokens' also gets updated. |
| void GenerateRandomEnc(uint32_t n_elements, |
| const std::vector<Token::Type>& used_types, |
| uint32_t max_asymbol, uint32_t n_dict, |
| uint32_t max_symbol, uint32_t max_range, |
| ANSBinSymbol* const bin_symbol, |
| ANSAdaptiveSymbol* const asymbol, |
| ANSDictionaries* const dicts, |
| UniformIntDistribution* const random, |
| std::vector<Token>* tokens, ANSEncBase* const enc) { |
| const bool has_symbols = (std::find(used_types.begin(), used_types.end(), |
| Token::Symbol) != used_types.end()); |
| uint32_t n_dict_ini = 0; |
| if (has_symbols) { |
| assert(n_dict > 0); |
| n_dict_ini = dicts->size(); |
| for (uint32_t i = 0; i < n_dict; ++i) { |
| ASSERT_WP2_OK(dicts->Add(max_symbol)); |
| } |
| } |
| const uint32_t size_ini = tokens->size(); |
| tokens->resize(size_ini + n_elements); |
| // Pre-compute the signal. |
| for (uint32_t i = size_ini; i < tokens->size(); ++i) { |
| const uint32_t type_idx = random->Get(0u, (uint32_t)used_types.size() - 1u); |
| const Token::Type t = used_types[type_idx]; |
| switch (t) { |
| case Token::Bit: { |
| uint32_t p, b; |
| DrawRandomBit(10, random, &p, &b); |
| (*tokens)[i] = Token(t, b, p); |
| break; |
| } |
| case Token::ABit: { |
| uint32_t b; |
| const uint32_t thresh = random->Get(0u, 10u); |
| b = (random->Get(0u, 10u) > thresh) ? 1 : 0; |
| (*tokens)[i] = Token(t, b); |
| break; |
| } |
| case Token::Bool: { |
| const bool b = random->Get(0u, 1u); |
| (*tokens)[i] = Token(t, b); |
| break; |
| } |
| case Token::ASymbol: { |
| const uint32_t s = random->Get(0u, max_asymbol - 1u); |
| (*tokens)[i] = Token(t, s); |
| break; |
| } |
| case Token::Symbol: { |
| const uint32_t d = n_dict_ini + random->Get(0u, n_dict - 1u); |
| const uint32_t s = random->Get(0u, max_symbol - 1u); |
| (*dicts)[d]->RecordSymbol(s); |
| (*tokens)[i] = Token(t, d, s); |
| break; |
| } |
| case Token::Uniform: |
| case Token::Signed: { |
| const uint32_t U = random->Get(0u, kANSMaxUniformBits); |
| const uint32_t v = random->Get(0u, (1u << U) - 1u); |
| (*tokens)[i] = Token(t, v, U); |
| break; |
| } |
| case Token::Range: { |
| const uint32_t R = |
| random->Get(1u, max_range ? max_range : kANSMaxRange); |
| const uint32_t v = random->Get(0u, R - 1u); |
| (*tokens)[i] = Token(t, v, R); |
| break; |
| } |
| case Token::MinMax: { |
| const uint32_t max = |
| random->Get(0u, max_range ? max_range - 1u : kANSMaxRange); |
| const uint32_t min = random->Get(0u, max); |
| const uint32_t v = random->Get(min, max); |
| (*tokens)[i] = Token(t, v, min, max); |
| break; |
| } |
| default: |
| assert(false); |
| } |
| } |
| if (has_symbols && dicts != nullptr) { |
| EXPECT_WP2_OK(dicts->ToCodingTable()); |
| } |
| for (uint32_t i = size_ini; i < tokens->size(); ++i) { |
| const Token& tok = (*tokens)[i]; |
| switch (tok.type) { |
| case Token::Bit: { |
| enc->PutBit(tok.first, tok.second, kProbaMax, "bit"); |
| break; |
| } |
| case Token::ABit: { |
| enc->PutABit(tok.first, bin_symbol, "abit"); |
| break; |
| } |
| case Token::Bool: { |
| enc->PutBool(tok.first, "bool"); |
| break; |
| } |
| case Token::ASymbol: { |
| if (asymbol != nullptr) enc->PutASymbol(tok.second, asymbol, "asymbol"); |
| break; |
| } |
| case Token::Symbol: { |
| enc->PutSymbol(tok.second, *(*dicts)[tok.first], "symbol"); |
| break; |
| } |
| case Token::Uniform: { |
| enc->PutUValue(tok.first, tok.second, "U_value"); |
| break; |
| } |
| case Token::Signed: { |
| enc->PutSUValue((int32_t)tok.first - ((1 << tok.second) >> 1), |
| tok.second, "S_value"); |
| break; |
| } |
| case Token::Range: { |
| enc->PutRValue(tok.first, tok.second, "R_value"); |
| break; |
| } |
| case Token::MinMax: { |
| enc->PutRange(tok.first, tok.second, tok.third, "minmax"); |
| break; |
| } |
| default: |
| assert(false); |
| } |
| } |
| } |
| |
| // Checks that an ANS buffer contains a list of tokens. |
| // 'codes' is the usual codes used for symbols in ANS: a vector of SymbolInfo |
| // per dictionary. |
| bool VerifyBuffer(const uint8_t* const buf, uint32_t size, |
| const std::vector<ANSCodes>* const codes, |
| ANSBinSymbol* const bin_symbol, |
| ANSAdaptiveSymbol* const asymbol, |
| const std::vector<Token>& tokens) { |
| ExternalDataSource data_source(buf, size); |
| ANSDec dec(&data_source); |
| for (const Token& tok : tokens) { |
| switch (tok.type) { |
| case Token::Bit: |
| if (dec.ReadBit(tok.second, kProbaMax, "bit") != (uint32_t)tok.first) { |
| printf("Bit value mismatch."); |
| return false; |
| } |
| break; |
| case Token::ABit: |
| assert(bin_symbol != nullptr); |
| if (dec.ReadABit(bin_symbol, "abit") != (uint32_t)tok.first) { |
| printf("ABit value mismatch."); |
| return false; |
| } |
| break; |
| case Token::Bool: |
| if (dec.ReadBool("bool") != (uint32_t)tok.first) { |
| printf("Bool value mismatch."); |
| return false; |
| } |
| break; |
| case Token::ASymbol: |
| assert(asymbol != nullptr); |
| if (dec.ReadASymbol(asymbol, "asymbol") != (uint32_t)tok.second) { |
| printf("ASymbol value mismatch."); |
| return false; |
| } |
| break; |
| case Token::Symbol: |
| assert(codes != nullptr); |
| if (dec.ReadSymbol((*codes)[tok.first].data(), ANS_LOG_TAB_SIZE, |
| "symbol") != (uint32_t)tok.second) { |
| printf("Symbol value mismatch."); |
| return false; |
| } |
| break; |
| case Token::Uniform: |
| if (dec.ReadUValue(tok.second, "U_value") != (uint32_t)tok.first) { |
| printf("U value mismatch."); |
| return false; |
| } |
| break; |
| case Token::Signed: |
| if (dec.ReadSUValue(tok.second, "S_value") + ((1 << tok.second) >> 1) != |
| tok.first) { |
| printf("S value mismatch."); |
| return false; |
| } |
| break; |
| case Token::Range: |
| if (dec.ReadRValue(tok.second, "R_value") != (uint32_t)tok.first) { |
| printf("R value mismatch."); |
| return false; |
| } |
| break; |
| case Token::MinMax: |
| if (dec.ReadRange(tok.second, tok.third, "minmax") != tok.first) { |
| printf("MinMax value mismatch."); |
| return false; |
| } |
| break; |
| default: |
| assert(0); |
| break; |
| } |
| } |
| return ((dec.GetStatus() == WP2_STATUS_OK) && |
| (data_source.GetNumNextBytes() == 0)); |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| class TestANS : public testing::Test { |
| void SetUp() override { |
| WP2DspReset(); |
| WP2MathInit(); |
| ANSInit(); |
| } |
| public: |
| static constexpr uint32_t kMaxRange = 1 << kANSMaxRangeBits; |
| }; |
| |
| // Test ANS for bits, U-value and R-values. |
| TEST_F(TestANS, Test0) { |
| const uint32_t kNumTokens = 700 * kTestFactor; |
| ANSEnc enc; |
| UniformIntDistribution random(kBaseSeed); |
| std::vector<Token> tokens; |
| |
| GenerateRandomEnc( |
| kNumTokens, {{Token::Bit, Token::Uniform, Token::Signed, Token::Range}}, |
| /*max_asymbol=*/0, /*n_dict=*/0, /*max_symbol=*/0, /*max_range=*/0, |
| /*bin_symbol=*/nullptr, /*asymbol=*/nullptr, /*dicts=*/nullptr, &random, |
| &tokens, &enc); |
| |
| EXPECT_WP2_OK(enc.AssembleToBitstream(true)); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| const uint32_t size = bits.size(); |
| const uint8_t* buf = bits.data(); |
| const int size_expected = (int)std::lround(enc.GetCostFull()); |
| if (kVerbose) { |
| printf("%u symbols -> actual bits: %d / expected bits:%d\n", |
| (uint32_t)tokens.size(), (int)size * 8, size_expected); |
| } |
| EXPECT_TRUE(VerifyBuffer(buf, size, /*codes=*/nullptr, /*bin_symbol=*/nullptr, |
| /*asymbol=*/nullptr, tokens)); |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| // Test different symbol inputs. |
| TEST_F(TestANS, Test2) { |
| const uint32_t kNumSymbols = 700 * kTestFactor; |
| UniformIntDistribution random(kBaseSeed); |
| const uint32_t kMaxSymbol = 256; |
| ANSCodes codes; |
| EXPECT_TRUE(codes.resize(31)); // will be resized later |
| |
| for (uint32_t type : {0, 1, 2}) { |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| EXPECT_WP2_OK(dicts.Add(kMaxSymbol)); |
| ANSDictionary* const dict = dicts.back(); |
| // Generate the symbols and fill the encoder. |
| std::vector<uint32_t> symbols(kNumSymbols); |
| Vector_u32 counts; |
| ASSERT_TRUE(counts.resize(kMaxSymbol)); |
| for (auto& c : counts) c = 0; |
| for (uint32_t& s : symbols) { |
| s = random.Get(0u, kMaxSymbol - 1u); |
| ++counts[s]; |
| dict->RecordSymbol(s); |
| } |
| // Depending on the case, change counts. |
| switch (type) { |
| case 0: { |
| // Nominal case, use the normal probabilities. |
| const Vector_u32& counts0 = dict->Counts(); |
| EXPECT_EQ(dict->MaxSymbol(), kMaxSymbol); |
| for (uint32_t i = 0; i < kMaxSymbol; ++i) { |
| EXPECT_EQ(counts[i], counts0[i]) << "Count error at #" << i; |
| } |
| break; |
| } |
| case 1: { |
| // Quantize using MaxFreq. |
| const uint32_t max_freq = |
| *std::max_element(std::begin(counts), std::end(counts)) / 2; |
| EXPECT_GE(max_freq, 1u); |
| EXPECT_TRUE(ANSCountsQuantize(false, max_freq, kMaxSymbol, &counts[0], |
| nullptr)); |
| EXPECT_WP2_OK(dict->SetQuantizedCounts(counts.data())); |
| break; |
| } |
| case 2: { |
| // Impose a pre-defined random set of probabilities. |
| for (uint32_t& c : counts) { |
| c = (c > 0u) ? random.Get(1u, 100u) : 0u; |
| } |
| EXPECT_WP2_OK(dict->SetQuantizedCounts(counts.data())); |
| break; |
| } |
| default: |
| EXPECT_FALSE(true); |
| } |
| EXPECT_WP2_OK(dict->ToCodingTable()); |
| for (uint32_t s : symbols) { |
| enc.PutSymbol(s, *dict, "id"); |
| } |
| const float cost1 = enc.GetCost(); |
| const float cost2 = enc.GetCost(dicts); |
| // Error is usually less than 0.1%. |
| EXPECT_NEAR(cost1, cost2, 0.001 * cost1) << "type: " << type; |
| |
| EXPECT_WP2_OK(enc.AssembleToBitstream()); |
| |
| // Decode the buffer. |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| ASSERT_WP2_OK(ANSCountsToSpreadTable(&counts[0], kMaxSymbol, |
| ANS_TAB_SIZE, codes)); |
| for (uint32_t i = 0; i < kNumSymbols; ++i) { |
| const uint32_t s = dec.ReadSymbol(codes.data(), ANS_LOG_TAB_SIZE, "id"); |
| EXPECT_EQ(s, symbols[i]) << "Error case " << type << " at position " << i; |
| } |
| } |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| // Test some convenience functions: Freeze, GetCounts, StoreVector, ReadVector. |
| TEST_F(TestANS, Test3) { |
| const uint32_t kNumSymbols = 700 * kTestFactor; |
| UniformIntDistribution random(kBaseSeed); |
| const uint32_t kMaxSymbol = 256; |
| |
| // Generate the symbols and fill the dictionary. |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| EXPECT_WP2_OK(dicts.Add(kMaxSymbol)); |
| ANSDictionary* const dict = dicts.back(); |
| std::vector<uint32_t> symbols(kNumSymbols); |
| for (uint32_t& i : symbols) { |
| i = random.Get(0u, kMaxSymbol - 1u); |
| dict->RecordSymbol(i); |
| } |
| |
| // Write the stream. |
| constexpr uint32_t vector_size = 10; |
| constexpr uint32_t num_vectors = 10; |
| uint32_t vector[num_vectors * vector_size]; |
| uint32_t nnzs[num_vectors]; |
| { |
| const Vector_u32& counts = dict->Counts(); |
| const uint32_t size = dict->MaxSymbol(); |
| const uint32_t max_freq = *std::max_element(counts.begin(), counts.end()); |
| |
| enc.PutRange(max_freq, 1, kANSMaxRange, "max_freq"); |
| enc.AddDebugPrefix("vector"); |
| std::vector<OptimizeArrayStorageStat> stats(size); |
| const float cost = |
| StoreVector(counts.data(), size, max_freq, stats.data(), &enc); |
| if (kVerbose) printf("Storage cost: %f\n", cost); |
| // Store random vectors. |
| for (uint32_t max_nnz = 1; max_nnz <= num_vectors; ++max_nnz) { |
| uint32_t* const sub_vector = &vector[vector_size * (max_nnz - 1)]; |
| std::fill(sub_vector, sub_vector + vector_size, 0u); |
| for (uint32_t j = 0; j < max_nnz; ++j) { |
| const uint32_t freq = random.Get(0, 100); |
| sub_vector[j == max_nnz - 1 ? vector_size - 1 |
| : random.Get(0u, vector_size - 1)] = freq; |
| } |
| nnzs[max_nnz - 1] = |
| vector_size - std::count(sub_vector, sub_vector + vector_size, 0); |
| const uint32_t val_upper = |
| *std::max_element(sub_vector, sub_vector + vector_size); |
| StoreVectorNnz(sub_vector, vector_size, nnzs[max_nnz - 1], val_upper, |
| stats.data(), &enc); |
| } |
| EXPECT_WP2_OK(dict->ToCodingTable()); |
| for (const uint32_t& i : symbols) enc.PutSymbol(i, *dict, "id"); |
| } |
| EXPECT_WP2_OK(enc.AssembleToBitstream()); |
| |
| // Decode the meta-info. |
| ANSCodes codes; |
| EXPECT_TRUE(codes.resize(535)); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| { |
| const uint32_t max_freq = dec.ReadRange(1, kANSMaxRange, "max_freq"); |
| Vector_u32 counts; |
| EXPECT_TRUE(counts.resize(kMaxSymbol)); |
| dec.AddDebugPrefix("vector"); |
| ReadVector(&dec, max_freq, counts); |
| // Read random vectors. |
| for (uint32_t max_nnz = 1; max_nnz <= num_vectors; ++max_nnz) { |
| std::array<uint32_t, vector_size> vector_read; |
| const uint32_t* const sub_vector = &vector[vector_size * (max_nnz - 1)]; |
| const uint32_t val_upper = |
| *std::max_element(sub_vector, sub_vector + vector_size); |
| ReadVectorNnz(&dec, nnzs[max_nnz - 1], val_upper, vector_read); |
| EXPECT_TRUE( |
| std::equal(sub_vector, sub_vector + vector_size, &vector_read[0])); |
| } |
| ASSERT_WP2_OK(ANSCountsToSpreadTable(&counts[0], kMaxSymbol, |
| ANS_TAB_SIZE, codes)); |
| } |
| |
| // Decode the buffer and make sure it matches the input |
| for (uint32_t i = 0; i < kNumSymbols; ++i) { |
| const uint32_t s = dec.ReadSymbol(codes.data(), ANS_LOG_TAB_SIZE, "id"); |
| EXPECT_EQ(s, symbols[i]) << "Error at position #" << i; |
| } |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| struct TokenInfo { |
| std::string name; |
| Token::Type type; |
| float err_threshold; // Threshold for the estimation error, in %. |
| }; |
| |
| class TestANSCostEstimation |
| : public testing::TestWithParam<std::tuple<uint32_t, TokenInfo>> {}; |
| |
| // Test the cost estimation method |
| TEST_P(TestANSCostEstimation, Simple) { |
| // Disable this test with WP2_ENC_DEC_MATCH as it considerably increases the |
| // size of the stream (by adding hashes) and therefore makes the thresholds |
| // useless. |
| #if !defined(WP2_ENC_DEC_MATCH) |
| const uint32_t kMaxSymbol = 256; |
| const uint32_t kMaxASymbol = 10; |
| const uint32_t kNDict = 3; |
| const uint32_t kNValues = 40000 * kTestFactor; |
| const TokenInfo& info = std::get<1>(GetParam()); |
| |
| WP2MathInit(); |
| ANSInit(); |
| |
| int size = 0; |
| ANSEnc enc0; |
| ANSEncCounter enc1; |
| for (ANSEncBase* const enc : {(ANSEncBase*)&enc0, (ANSEncBase*)&enc1}) { |
| UniformIntDistribution random(/*seed=*/std::get<0>(GetParam())); |
| ANSDictionaries sym_dicts; |
| ANSBinSymbol bin_symbol; |
| ANSAdaptiveSymbol asymbol; |
| asymbol.SetAdaptationSpeed(ANSAdaptiveSymbol::Method::kAOM); |
| asymbol.InitFromUniform(APROBA_MAX_SYMBOL); |
| std::vector<Token> tokens; |
| |
| const uint32_t n_dict = (info.type == Token::Symbol) ? kNDict : 0; |
| const uint32_t max_symbol = (info.type == Token::Symbol) ? kMaxSymbol : 0; |
| const uint32_t max_asymbol = |
| (info.type == Token::ASymbol) ? kMaxASymbol : 0; |
| const uint32_t max_range = 0; |
| ANSBinSymbol* const bin_symbol_stats = |
| (info.type == Token::ABit) ? &bin_symbol : nullptr; |
| ANSAdaptiveSymbol* const asymbol_stats = |
| (info.type == Token::ASymbol) ? &asymbol : nullptr; |
| ANSDictionaries* const dicts = |
| (info.type == Token::Symbol) ? &sym_dicts : nullptr; |
| |
| GenerateRandomEnc(kNValues, {info.type}, max_asymbol, n_dict, max_symbol, |
| max_range, bin_symbol_stats, asymbol_stats, dicts, |
| &random, &tokens, enc); |
| |
| if (enc == &enc0) { |
| EXPECT_WP2_OK(enc0.AssembleToBitstream()); |
| Vector_u8 bits0; |
| EXPECT_WP2_OK(enc0.WriteBitstreamTo(bits0)); |
| size = (int)(8 * bits0.size()); // store for later comparison |
| if (info.type != Token::Symbol) { // omit the dictionary case for now |
| ANSBinSymbol new_bin_symbol; |
| ANSAdaptiveSymbol new_asymbol; |
| new_asymbol.SetAdaptationSpeed(ANSAdaptiveSymbol::Method::kAOM); |
| new_asymbol.InitFromUniform(APROBA_MAX_SYMBOL); |
| EXPECT_TRUE(VerifyBuffer(bits0.data(), bits0.size(), |
| /*codes=*/nullptr, &new_bin_symbol, |
| &new_asymbol, tokens)); |
| } |
| } |
| |
| const int size_estimated = (int)std::lround(enc->GetCostFull(sym_dicts)); |
| const float error = 100.f * size / size_estimated - 100.f; |
| EXPECT_LE(fabs((double)error), info.err_threshold); |
| if (kVerbose) { |
| printf("%s: actual bits: %d / estimated bits:%d. %1.3f%% bigger.\n", |
| info.name.c_str(), size, size_estimated, error); |
| } |
| } |
| #endif // WP2_ENC_DEC_MATCH |
| } |
| |
| INSTANTIATE_TEST_SUITE_P( |
| TestANSCostEstimationInstantiation, TestANSCostEstimation, |
| testing::Combine(testing::Values(kBaseSeed), |
| testing::ValuesIn(std::vector<TokenInfo>{ |
| {"Bit ", Token::Bit, 0.15f}, |
| {"Bool ", Token::Bool, 0.01f}, |
| {"ASymbol", Token::ASymbol, 2.0f}, |
| {"Symbol ", Token::Symbol, 0.02f}, |
| {"Range ", Token::Range, 0.05f}, |
| {"MinMax ", Token::MinMax, 0.06f}, |
| {"Uniform", Token::Uniform, 0.01f}, |
| {"Signed ", Token::Signed, 0.01f}, |
| {"ABit ", Token::ABit, 3.2f}}))); |
| |
| //------------------------------------------------------------------------------ |
| |
| // Test the Quantizer. |
| TEST_F(TestANS, Test5a) { |
| UniformIntDistribution random(/*seed=*/kBaseSeed); |
| const uint32_t kMaxSymbol = 256; |
| const uint32_t kNumSymbols = 7000 * kTestFactor; |
| |
| // Create some stream with symbols. |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| EXPECT_WP2_OK(dicts.Add(kMaxSymbol)); |
| ANSDictionary* const dict = dicts.back(); |
| for (uint32_t i = 0; i < kNumSymbols; ++i) { |
| dict->RecordSymbol(random.Get(0u, kMaxSymbol - 1u)); |
| } |
| |
| // Compute the sparse histogram. |
| const Vector_u32& counts = dict->Counts(); |
| std::vector<uint32_t> histogram(kMaxSymbol); |
| std::vector<uint16_t> mapping(kMaxSymbol); |
| uint32_t size_sparse = 0; |
| for (uint32_t i = 0; i < dict->MaxSymbol(); ++i) { |
| if (counts[i] == 0) continue; |
| mapping[size_sparse] = (uint16_t)i; |
| histogram[size_sparse++] = counts[i]; |
| } |
| |
| // Optimize the storage of proba + symbols. |
| Quantizer quantizer; |
| EXPECT_WP2_OK(quantizer.Allocate(kMaxSymbol)); |
| |
| // Test the statefullness of the quantizer. |
| EXPECT_TRUE(quantizer.Quantize(histogram.data(), mapping.data(), size_sparse, |
| kMaxSymbol, kNumSymbols, |
| /*cost_max=*/std::numeric_limits<float>::max(), |
| /*cost_offset=*/0.f, /*effort=*/5)); |
| const float cost1 = quantizer.GetBest()->cost; |
| EXPECT_FALSE(quantizer.Quantize( |
| histogram.data(), mapping.data(), size_sparse, kMaxSymbol, kNumSymbols, |
| /*cost_max=*/std::numeric_limits<float>::max(), |
| /*cost_offset=*/0.f, /*effort=*/5)); |
| // Now, use twice less data which will result in a lower cost. |
| EXPECT_TRUE(quantizer.Quantize(histogram.data(), mapping.data(), |
| size_sparse / 2, kMaxSymbol, kNumSymbols, |
| /*cost_max=*/std::numeric_limits<float>::max(), |
| /*cost_offset=*/0.f, /*effort=*/5)); |
| const float cost2 = quantizer.GetBest()->cost; |
| EXPECT_LT(cost2, cost1); |
| |
| for (int effort : {0, 5, 9}) { |
| quantizer.ResetBest(); |
| quantizer.Quantize(histogram.data(), mapping.data(), size_sparse, |
| kMaxSymbol, kNumSymbols, |
| /*cost_max=*/std::numeric_limits<float>::max(), |
| /*cost_offset=*/0.f, effort); |
| Quantizer::Config* config = quantizer.GetBest(); |
| |
| // Display results. |
| if (kVerbose) { |
| printf("Cost: %f\n", config->cost); |
| if (config->param.is_sparse) { |
| printf("Histogram is sparse:\n(index, count - 1):\n"); |
| for (uint32_t i = 0; i < config->size_to_write; ++i) { |
| printf("(%d,%d) ", mapping[i], config->histogram_to_write[i]); |
| } |
| printf("\n"); |
| } else { |
| printf("Histogram is not sparse:\n"); |
| for (uint32_t i = 0; i < config->size_to_write; ++i) { |
| printf("%d ", config->histogram_to_write[i]); |
| } |
| printf("\n"); |
| } |
| if (config->param.type == Quantizer::ConfigType::Huffman) { |
| printf("The counts above should be interpreted as exponents of 2.\n"); |
| } |
| } |
| } |
| } |
| |
| TEST_F(TestANS, Test5b) { |
| constexpr uint32_t kSize = 1000; |
| constexpr uint32_t kIndex = 501; |
| constexpr uint32_t kMaxCount = 10000; |
| constexpr uint32_t kMaxCountQuantized = 100; |
| std::vector<uint32_t> counts(kSize, 0); |
| counts[kIndex] = kMaxCount; |
| |
| // Test that the cost of non-quantized histogram is 0. |
| const float cost1 = ANSCountsCost(&counts[0], counts.size()); |
| EXPECT_EQ(cost1, 0.f); |
| |
| // Test that the cost of quantized histogram is 0. |
| std::vector<uint32_t> distribution = counts; |
| float cost2; |
| ANSCountsQuantize(/*do_expand=*/false, kMaxCountQuantized, counts.size(), |
| &counts[0], &distribution[0], &cost2); |
| EXPECT_EQ(distribution[kIndex], kMaxCountQuantized); |
| EXPECT_EQ(cost2, 0.f); |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| /** |
| * Test dedicated to encoding of integer in a range. |
| */ |
| TEST_F(TestANS, Test6a) { |
| constexpr uint32_t kNumTokens = 700 * kTestFactor; |
| // Verify behavior with randomized inputs. |
| for (Token::Type type : {Token::Range, Token::MinMax}) { |
| UniformIntDistribution random(kBaseSeed); |
| ANSEnc encoder; |
| std::vector<Token> tokens; |
| |
| GenerateRandomEnc(kNumTokens, {type}, /*max_asymbol=*/0, /*n_dict=*/0, |
| /*max_symbol=*/0, kMaxRange, /*bin_symbol=*/nullptr, |
| /*asymbol=*/nullptr, /*dicts=*/nullptr, &random, &tokens, |
| &encoder); |
| EXPECT_WP2_OK(encoder.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(encoder.WriteBitstreamTo(bits)); |
| EXPECT_TRUE(VerifyBuffer(bits.data(), bits.size(), |
| /*codes=*/nullptr, /*bin_symbol=*/nullptr, |
| /*asymbol=*/nullptr, tokens)); |
| } |
| } |
| |
| TEST_F(TestANS, Test6b) { |
| // Exhaustive verification for small integer range going from 1 to |
| // max_tested_range. |
| const uint32_t max_tested_range = 256; |
| ANSEnc encoder; |
| std::vector<Token> tokens; |
| tokens.reserve(max_tested_range * (max_tested_range - 1) / 2); |
| for (uint16_t range = 1; range < max_tested_range; ++range) { |
| for (uint16_t value = 0; value < range; ++value) { |
| encoder.PutRValue(value, range, "R_value"); |
| tokens.emplace_back(Token::Range, value, range); |
| } |
| } |
| EXPECT_WP2_OK(encoder.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(encoder.WriteBitstreamTo(bits)); |
| EXPECT_TRUE(VerifyBuffer(bits.data(), bits.size(), |
| /*codes=*/nullptr, /*bin_symbol=*/nullptr, |
| /*asymbol=*/nullptr, tokens)); |
| } |
| |
| TEST_F(TestANS, Test6c) { |
| // Verification for large integer ranges going from min_tested_range to |
| // kMaxRange. |
| const uint32_t min_tested_range = kMaxRange - 255; |
| ANSEnc encoder; |
| std::vector<Token> tokens; |
| tokens.reserve(kMaxRange * (kMaxRange - 1) / 2 - |
| (min_tested_range - 1) * (min_tested_range - 2) / 2); |
| for (uint16_t range = min_tested_range; range < kMaxRange; ++range) { |
| for (uint16_t value = range / 2; value < range; ++value) { |
| encoder.PutRValue(value, range, "R_value"); |
| tokens.emplace_back(Token::Range, value, range); |
| } |
| } |
| EXPECT_WP2_OK(encoder.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(encoder.WriteBitstreamTo(bits)); |
| EXPECT_TRUE(VerifyBuffer(bits.data(), bits.size(), |
| /*codes=*/nullptr, /*bin_symbol=*/nullptr, |
| /*asymbol=*/nullptr, tokens)); |
| } |
| |
| TEST_F(TestANS, Test6d) { |
| // Sparse verification for min/max ranges. |
| ANSEnc encoder; |
| std::vector<Token> tokens; |
| for (uint16_t hi = 0; hi < kMaxRange; hi += hi / 2 + 1) { |
| for (uint16_t lo = 0; lo <= hi; lo += lo / 2 + 1) { |
| for (uint16_t value = lo; value <= hi; ++value) { |
| encoder.PutRange(value, lo, hi, "minmax"); |
| tokens.emplace_back(Token::MinMax, value, lo, hi); |
| } |
| } |
| } |
| EXPECT_WP2_OK(encoder.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(encoder.WriteBitstreamTo(bits)); |
| EXPECT_TRUE(VerifyBuffer(bits.data(), bits.size(), |
| /*codes=*/nullptr, /*bin_symbol=*/nullptr, |
| /*asymbol=*/nullptr, tokens)); |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| // Generate the codes for each dictionary, given a list of tokens. |
| void GetCodesFromTokens(const std::vector<uint32_t>& max_symbols, |
| const std::vector<Token>& tokens, |
| std::vector<ANSCodes>& codes) { |
| // Compute the counts of symbols. |
| const uint32_t n_dicts = max_symbols.size(); |
| std::vector<std::vector<uint32_t>> counts(n_dicts); |
| for (uint32_t i = 0; i < n_dicts; ++i) counts[i].resize(max_symbols[i], 0); |
| for (const Token& tok : tokens) { |
| if (tok.type == Token::Symbol) ++counts[tok.first][tok.second]; |
| } |
| |
| codes.resize(n_dicts); |
| for (uint32_t i = 0; i < n_dicts; ++i) { |
| EXPECT_WP2_OK(ANSCountsToSpreadTable( |
| &counts[i][0], max_symbols[i], ANS_TAB_SIZE, codes[i])); |
| } |
| } |
| |
| // Test for miscellaneous ANS functions |
| TEST_F(TestANS, Test7) { |
| const uint32_t kNumElements = 700 * kTestFactor; |
| WP2MathInit(); |
| ANSInit(); |
| // Test for adding new dictionaries after some symbols are inserted. |
| { |
| const std::vector<uint32_t> max_symbols = {17, 17, 13, 29, 101}; |
| std::vector<Token> tokens; |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| UniformIntDistribution random(kBaseSeed); |
| for (uint32_t max_symbol : max_symbols) { |
| GenerateRandomEnc(kNumElements, {Token::Symbol}, /*max_asymbol=*/0, |
| /*n_dict=*/1, max_symbol, /*max_range=*/0, |
| /*bin_symbol=*/nullptr, /*asymbol=*/nullptr, &dicts, |
| &random, &tokens, &enc); |
| } |
| std::vector<ANSCodes> codes; |
| GetCodesFromTokens(max_symbols, tokens, codes); |
| EXPECT_WP2_OK(enc.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| EXPECT_TRUE(VerifyBuffer(bits.data(), bits.size(), &codes, |
| /*bin_symbol=*/nullptr, /*asymbol=*/nullptr, |
| tokens)); |
| } |
| // Test appending. |
| { |
| const std::vector<uint32_t> max_symbols[2] = {{19, 21, 3, 512}, |
| {7, 5, 405}}; |
| UniformIntDistribution random(kBaseSeed); |
| ANSEnc enc[2]; |
| ANSDictionaries dicts[2]; |
| std::vector<Token> tokens; |
| for (uint32_t i = 0; i < 2; ++i) { |
| for (uint32_t max_symbol : max_symbols[i]) { |
| GenerateRandomEnc( |
| kNumElements, {Token::Symbol}, /*max_asymbol=*/0, /*n_dict=*/1, |
| max_symbol, /*max_range=*/0, /*bin_symbol=*/nullptr, |
| /*asymbol=*/nullptr, &dicts[i], &random, &tokens, &enc[i]); |
| } |
| } |
| EXPECT_WP2_OK(enc[0].AppendTokens(enc[1], 0, enc[1].NumTokens())); |
| EXPECT_WP2_OK(dicts[0].AppendAndClear(&dicts[1])); |
| EXPECT_WP2_OK(dicts[0].ToCodingTable()); |
| EXPECT_WP2_OK(enc[0].AssembleToBitstream()); |
| // Update the ids in the tokens. |
| uint32_t n_dict_0 = max_symbols[0].size(); |
| for (uint32_t i = kNumElements * n_dict_0; i < tokens.size(); ++i) { |
| tokens[i].first += n_dict_0; |
| } |
| // Update the max_symbols. |
| std::vector<uint32_t> max_symbols_new = max_symbols[0]; |
| max_symbols_new.insert(max_symbols_new.end(), max_symbols[1].begin(), |
| max_symbols[1].end()); |
| |
| // Decode the buffer. |
| std::vector<ANSCodes> codes; |
| GetCodesFromTokens(max_symbols_new, tokens, codes); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc[0].WriteBitstreamTo(bits)); |
| EXPECT_TRUE(VerifyBuffer(bits.data(), bits.size(), &codes, |
| /*bin_symbol=*/nullptr, /*asymbol=*/nullptr, |
| tokens)); |
| } |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| struct TokenSetInfo { |
| std::string name; |
| std::vector<Token::Type> types; |
| }; |
| |
| class TestANSBitCounts |
| : public testing::TestWithParam<std::tuple<uint32_t, TokenSetInfo>> {}; |
| |
| // This test only works in WP2_BITTRACE mode! |
| TEST_P(TestANSBitCounts, Simple) { |
| #if defined(WP2_BITTRACE) |
| const int kNumDicts = 3; |
| const std::vector<uint32_t> max_symbols(kNumDicts, 256); |
| |
| WP2MathInit(); |
| ANSInit(); |
| |
| const TokenSetInfo& info = std::get<1>(GetParam()); |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| UniformIntDistribution random(/*seed=*/std::get<0>(GetParam())); |
| std::vector<Token> tokens; |
| GenerateRandomEnc(/*n_elements=*/1000 * kTestFactor, info.types, |
| /*max_asymbol=*/0, kNumDicts, /*max_symbol=*/256, |
| /*max_range=*/3263, /*bin_symbol=*/nullptr, |
| /*asymbol=*/nullptr, &dicts, &random, &tokens, &enc); |
| std::vector<ANSCodes> codes(kNumDicts); |
| if (std::find(info.types.begin(), info.types.end(), Token::Symbol) != |
| info.types.end()) { |
| GetCodesFromTokens(max_symbols, tokens, codes); |
| } |
| EXPECT_WP2_OK(enc.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| for (const Token& tok : tokens) { |
| if (tok.type == Token::Bit) { |
| (void)dec.ReadBit(tok.second, kProbaMax, "bit"); |
| } else if (tok.type == Token::Bool) { |
| (void)dec.ReadBool("bool"); |
| } else if (tok.type == Token::Symbol) { |
| (void)dec.ReadSymbol(codes[tok.first].data(), ANS_LOG_TAB_SIZE, "symbol"); |
| } else if (tok.type == Token::Uniform) { |
| (void)dec.ReadUValue(tok.second, "U_value"); |
| } else if (tok.type == Token::Range) { |
| (void)dec.ReadRValue(tok.second, "R_value"); |
| } else if (tok.type == Token::MinMax) { |
| (void)dec.ReadRange(tok.second, tok.third, "minmax"); |
| } else if (tok.type == Token::Signed) { |
| (void)dec.ReadSUValue(tok.second, "S_value"); |
| } else { |
| EXPECT_FALSE(true); |
| } |
| } |
| EXPECT_WP2_OK(dec.GetStatus()); |
| const double bit_count = dec.GetBitCount(); |
| |
| // difference between final size and bit-trace: |
| const double error = (bits.size() - bit_count / 8.) / bits.size(); |
| if (kVerbose) { |
| printf("%s: %d elemt. Bits: actual: %d, reported: %.2lf (%.2lf%% off)\n", |
| info.name.c_str(), (int)tokens.size(), |
| (int)bits.size() * 8, bit_count, error * 100.f); |
| } |
| EXPECT_LE(std::abs(error), ANSDec::kBitCountAccuracy); |
| #endif |
| } |
| |
| INSTANTIATE_TEST_SUITE_P( |
| TestANSBitCountsInstantiation, TestANSBitCounts, |
| testing::Combine(testing::Values(kBaseSeed), |
| testing::ValuesIn(std::vector<TokenSetInfo>{ |
| {"Bit ", {Token::Bit}}, |
| {"Bool ", {Token::Bool}}, |
| {"Symbol ", {Token::Symbol}}, |
| {"Range ", {Token::Range}}, |
| {"MinMax ", {Token::MinMax}}, |
| {"Uniform ", {Token::Uniform}}, |
| {"Signed ", {Token::Signed}}, |
| {"All ", |
| {Token::Bit, Token::Symbol, Token::Uniform, |
| Token::Range}}}))); |
| |
| // Make sure ANS does not spin forever when stats have a tail. |
| TEST_F(TestANS, TestSpin) { |
| ANSDictionary dict; |
| EXPECT_WP2_OK(dict.Init(20)); |
| dict.RecordSymbol(0, 2 * ANS_TAB_SIZE); |
| for (uint32_t i = 1; i < 20; ++i) dict.RecordSymbol(i, 1); |
| // This should not spin. |
| EXPECT_WP2_OK(dict.ToCodingTable()); |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| class TestANSAdaptiveSymbol : public testing::TestWithParam<uint32_t> {}; |
| |
| TEST_P(TestANSAdaptiveSymbol, Simple) { |
| const uint32_t kNumSymbols = 1000 * kTestFactor; |
| UniformIntDistribution random(/*seed=*/GetParam()); |
| |
| WP2MathInit(); |
| ANSInit(); |
| |
| uint32_t rnd_range = (random.Get(0, 35) == 0) |
| ? (2u << APROBA_BITS) / APROBA_MAX_SYMBOL |
| : random.Get(1u, 30u); |
| uint32_t cumul[APROBA_MAX_SYMBOL + 1]; |
| uint32_t real_cumul[APROBA_MAX_SYMBOL] = {0}; |
| uint32_t sum = 0; |
| for (uint32_t& c : cumul) { |
| sum = c = random.Get(sum, sum + rnd_range + 1u); |
| rnd_range = rnd_range * 3 / 4; |
| } |
| |
| std::vector<uint8_t> syms(kNumSymbols); |
| for (uint8_t& sym : syms) { |
| const uint32_t p = random.Get(0u, sum); |
| uint32_t s = 0; |
| for (s = 0; s < APROBA_MAX_SYMBOL - 1; ++s) { |
| if (p < cumul[s + 1]) break; |
| } |
| sym = s; |
| ++real_cumul[s]; |
| } |
| |
| if (random.Get(0, 19) == 0) { |
| std::sort(syms.begin(), syms.end()); // we'll have to be very adaptive! |
| } |
| |
| // Encode |
| const uint32_t adapt_speed = random.Get(0u, APROBA_BITS - 1u); |
| ANSAdaptiveSymbol w_dict; |
| w_dict.SetAdaptationSpeed(ANSAdaptiveSymbol::Method::kConstant, adapt_speed); |
| w_dict.InitFromUniform(APROBA_MAX_SYMBOL); |
| ANSEnc enc; |
| for (const uint8_t& s : syms) enc.PutASymbol(s, &w_dict, "asym"); |
| EXPECT_WP2_OK(enc.AssembleToBitstream(true)); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| if (kVerbose) { |
| for (uint32_t s = 0; s < APROBA_MAX_SYMBOL; ++s) { |
| const float real_p = real_cumul[s] * 100.f / syms.size(); |
| const float dict_p = w_dict.GetProba(s) * 100.f; |
| printf( |
| "#%d theory:%.2f%% real:%.2f%% adapt-dict:%.2f%% -> error=%.1f%%\n", |
| s, (cumul[s + 1] - cumul[s]) * 100.f / sum, real_p, dict_p, |
| fabs(1. - dict_p / real_p) * 100.); |
| } |
| printf("%d symbols, size=%d\n", (int)syms.size(), (int)bits.size()); |
| } |
| |
| // Decode and check |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| ANSAdaptiveSymbol r_dict; |
| r_dict.SetAdaptationSpeed(ANSAdaptiveSymbol::Method::kConstant, adapt_speed); |
| r_dict.InitFromUniform(APROBA_MAX_SYMBOL); |
| for (const uint8_t& sym : syms) { |
| const uint32_t s = dec.ReadASymbol(&r_dict, "asym"); |
| EXPECT_EQ(s, sym) << ": expected symbol " << sym << ", got " << s; |
| } |
| |
| // Match ending probabilities |
| for (uint32_t s = 0; s < APROBA_MAX_SYMBOL; ++s) { |
| EXPECT_EQ(w_dict.GetProba(s), r_dict.GetProba(s)) |
| << "Symbol #" << s << ": Final proba error: expected " |
| << w_dict.GetProba(s) << ", got " << r_dict.GetProba(s); |
| } |
| } |
| |
// Codes one random message twice and compares the sizes against each other
// and against the entropy bound:
//  1. with an ANSAdaptiveSymbol (constant speed, initialized from the exact
//     symbol counts);
//  2. with a static ANSDictionary built from the same message.
// It then interleaves both coders in one stream and checks that
// ANSEnc::GetCost() predicts the resulting bitstream size.
TEST_P(TestANSAdaptiveSymbol, VsDictionary) {
  // Disable this test with WP2_ENC_DEC_MATCH as it considerably increases the
  // size of the stream (by adding hashes) and therefore makes the thresholds
  // useless.
#if !defined(WP2_ENC_DEC_MATCH)
  UniformIntDistribution random(/*seed=*/GetParam());

  // generate a pdf[] (unnormalized per-symbol counts)
  uint32_t len = random.Get(1000u, 1999u) * kTestFactor;
  uint32_t pdf[APROBA_MAX_SYMBOL];
  for (uint32_t& i : pdf) i = random.Get(0u, len / 16u - 1u);
  const uint32_t speed = random.Get(0u, 0x3fffu);  // *must* be < 32768 in SSE2

  // remove some symbols: zero out (APROBA_MAX_SYMBOL - nsym) randomly picked
  // entries. Picks may repeat, so fewer symbols may actually be removed.
  // With probability 1/20, no symbol is removed.
  uint32_t nsym =
      (random.Get(0, 19) == 0) ? APROBA_MAX_SYMBOL : random.Get(6u, 15u);
  while (nsym++ < APROBA_MAX_SYMBOL) {
    pdf[random.Get(0u, APROBA_MAX_SYMBOL - 1u)] = 0;
  }

  // The message length is the total count.
  len = 0;
  for (uint32_t c : pdf) len += c;

  // generate a message: each symbol s repeated pdf[s] times.
  std::vector<uint8_t> msg(len);
  for (uint32_t s = 0, k = 0; s < APROBA_MAX_SYMBOL; ++s) {
    for (uint32_t n = 0; n < pdf[s]; ++n) msg[k++] = (uint8_t)s;
  }
  // shuffle the message
  Shuffle(msg.begin(), msg.end(), /*seed=*/GetParam());

  // compute the theoretical bit length: the Shannon bound
  // sum(-v * log2(v / len)) plus a fixed allowance for padding.
  float total = 32.;  // ~32 bits for ANS padding
  for (uint32_t v : pdf) {
    if (v) total -= log2(1. * v / len) * v;
  }

  WP2MathInit();
  ANSInit();

  // populate dictionary stats
  ANSDictionaries dicts;
  EXPECT_WP2_OK(dicts.Add(APROBA_MAX_SYMBOL));
  ANSDictionary* const dict = dicts.back();
  for (uint8_t s : msg) dict->RecordSymbol(s);

  // Pass 1: adaptive symbol, initialized from the exact counts.
  ANSEnc enc;
  ANSAdaptiveSymbol asym;
  asym.SetAdaptationSpeed(ANSAdaptiveSymbol::Method::kConstant, speed);
  EXPECT_WP2_OK(asym.InitFromCounts(pdf, APROBA_MAX_SYMBOL));

  for (uint8_t s : msg) enc.PutASymbol(s, &asym, "test1");
  EXPECT_WP2_OK(enc.AssembleToBitstream());
  const float size1 = 8.f * enc.GetBitstreamSize();

  // Pass 2: static dictionary built from the whole message.
  enc.Reset();
  EXPECT_WP2_OK(dict->ToCodingTable());
  for (uint8_t s : msg) enc.PutSymbol(s, *dict, "test2");
  EXPECT_WP2_OK(enc.AssembleToBitstream());
  const float size2 = 8.f * enc.GetBitstreamSize();

  // tolerance of ASym vs Dictionary, indexed by 'lspeed' below
  const float kTolerance[16] = {
      350.f, 150.f, 100.f, 50.f, 15.f, 7.2f, 3.5f, 2.8f, 2.5f, 2.2f,
      // For some reasons, starting at speed=10, it becomes more unstable.
      6.0f, 20.0f, 50.0f, 70.f, 50.f,
      // Speed=15 is really non-adaptive, though, with low error.
      0.3f};
  const float kTolerance2 = 0.2f;  // dictionary tolerance (quite low!)
  // Both errors are percentages of the theoretical bit length.
  const float err1 = fabs(100. * (size1 - size2) / total);
  const float err2 = fabs(100. * (size2 - total) / total);
  // lspeed = 15 - floor(log2(1 + speed)). It lies in [1, 15] for speed in
  // [0, 0x3fff], so kTolerance[0] is never used.
  const int lspeed = 15 - (int)log2(1 + speed);
  EXPECT_LT(err1, kTolerance[lspeed]);
  EXPECT_LT(err2, kTolerance2);

  if (kVerbose || err1 > kTolerance[lspeed] || err2 > kTolerance2) {
    printf("#%d symbols, expected bits=%.1f ", len, total);
    printf("final bits1 = %.1f bits2=%.1f [speed=%d lspeed=%d %.3f]\n",
           size1, size2, speed, lspeed, err1);
  }

  // Write one symbol with sym and the next with asym.
  // The dictionary only records the even-indexed symbols it will code.
  dict->ResetCounts();
  for (uint32_t i = 0; i < msg.size(); i += 2) dict->RecordSymbol(msg[i]);
  EXPECT_WP2_OK(dict->ToCodingTable());

  enc.Reset();
  EXPECT_WP2_OK(asym.InitFromCounts(pdf, APROBA_MAX_SYMBOL));

  for (uint32_t i = 0; i < msg.size(); ++i) {
    if (i % 2 == 0) {
      enc.PutSymbol(msg[i], *dict, "sym");
    } else {
      enc.PutASymbol(msg[i], &asym, "asym");
    }
  }
  EXPECT_WP2_OK(enc.AssembleToBitstream());
  const float size = 8.f * enc.GetBitstreamSize();
  // Estimated cost of the interleaved stream, which must match its real size.
  const float cost = enc.GetCost(dicts);

  // tolerance of ASym vs Dictionary (percentage of the actual size)
  const float err = (size == 0.f) ? 0.f : fabs(100. * (size - cost) / size);
  constexpr float kTolerance3 = 1.05f;
  EXPECT_LT(err, kTolerance3);

  if (kVerbose || err > kTolerance3) {
    printf("#%d symbols, expected cost=%.1f ", len, cost);
    printf("final bits = %.1f [speed=%d lspeed=%d %.3f]\n",
           size, speed, lspeed, err);
  }
#endif  // WP2_ENC_DEC_MATCH
}
| |
| INSTANTIATE_TEST_SUITE_P(TestANSAdaptiveSymbolInstantiation, |
| TestANSAdaptiveSymbol, |
| testing::Range(kBaseSeed, |
| kBaseSeed + 30 * kTestFactor)); |
| |
| //------------------------------------------------------------------------------ |
| |
| void TestUniformProbability(ANSAdaptiveSymbol* const asym, |
| uint16_t max_symbol) { |
| uint32_t symbol_to_num_occurrences[APROBA_MAX_SYMBOL]{0u}; |
| for (uint32_t proba = 0; proba < APROBA_MAX; ++proba) { |
| const uint16_t symbol = asym->GetSymbol(proba).symbol; |
| EXPECT_LT(symbol, max_symbol); |
| ++symbol_to_num_occurrences[symbol]; |
| } |
| // Make sure they have almost equal probabilities. |
| uint32_t num_occurrences1 = symbol_to_num_occurrences[0]; |
| uint32_t num_occurrences2 = symbol_to_num_occurrences[0]; |
| for (uint16_t symbol = 1; symbol < max_symbol; ++symbol) { |
| if (symbol_to_num_occurrences[symbol] != num_occurrences1) { |
| if (num_occurrences1 == num_occurrences2) { |
| num_occurrences2 = symbol_to_num_occurrences[symbol]; |
| EXPECT_EQ(num_occurrences1 + 1, num_occurrences2); |
| } else { |
| EXPECT_EQ(symbol_to_num_occurrences[symbol], num_occurrences2); |
| } |
| } |
| } |
| } |
| |
| TEST_F(TestANS, TestANSAdaptiveSymbol) { |
| ANSAdaptiveSymbol asym; |
| for (uint16_t max_symbol = 1; max_symbol < APROBA_MAX_SYMBOL; ++max_symbol) { |
| asym.SetAdaptationSpeed(ANSAdaptiveSymbol::Method::kAOM); |
| asym.InitFromUniform(max_symbol); |
| TestUniformProbability(&asym, max_symbol); |
| } |
| } |
| |
| // Tests that symbols can be initialized from CDF's. |
| TEST_F(TestANS, TestANSAdaptiveSymbolFromCDF) { |
| UniformIntDistribution random; |
| for (uint16_t max_symbol = 1; max_symbol < APROBA_MAX_SYMBOL; ++max_symbol) { |
| // Get random probabilities. |
| uint16_t counts[APROBA_MAX_SYMBOL] = {0u}; |
| uint32_t sum = 0u; |
| for (uint32_t s = 0; s < max_symbol; ++s) { |
| counts[s] = random.Get((s == 0) ? 1u : 0u, 100u); |
| sum += counts[s]; |
| } |
| // Deduce a cdf. |
| uint16_t cdf[APROBA_MAX_SYMBOL] = {0}; |
| for (uint32_t s = 1; s < max_symbol; ++s) { |
| cdf[s] = cdf[s - 1] + counts[s - 1]; |
| } |
| ANSAdaptiveSymbol asym; |
| EXPECT_WP2_OK(asym.InitFromCDF(cdf, max_symbol, sum)); |
| // Verify probabilities are the same. |
| for (uint32_t s = 0; s < max_symbol; ++s) { |
| EXPECT_NEAR((double)asym.GetProba(s), (double)counts[s] / sum, 0.001); |
| } |
| } |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| class TestANSMapping : public testing::TestWithParam<uint32_t> {}; |
| |
| TEST_P(TestANSMapping, Simple) { |
| ANSInit(); |
| UniformIntDistribution random(kBaseSeed); |
| const uint16_t range = random.Get<uint16_t>(1, 100); |
| std::vector<OptimizeArrayStorageStat> stats(range); |
| std::vector<uint16_t> mapping_ini(range); |
| Vector_u16 mapping_res; |
| const uint32_t pattern_type = GetParam(); |
| for (uint32_t n = 1; n < range; ++n) { |
| if (pattern_type == 0) { |
| mapping_ini[0] = random.Get(0u, range - n); |
| for (uint32_t j = 1; j < n; ++j) { |
| mapping_ini[j] = |
| random.Get<uint16_t>(mapping_ini[j - 1] + 1, range - (n - j)); |
| } |
| } else { |
| for (uint32_t j = 0; j < n / 2; ++j) mapping_ini[j] = j; |
| for (uint32_t j = n / 2; j < n; ++j) mapping_ini[j] = j + range - n; |
| } |
| for (uint32_t j = 1; j < n; ++j) { |
| ASSERT_LT(mapping_ini[j - 1], mapping_ini[j]) |
| << "pos: " << j << " size:" << n << " range:" << range; |
| ASSERT_LT(mapping_ini[j], range) << "pos: " << j; |
| } |
| // Store the mapping to an encoder. |
| ANSEnc enc; |
| StoreMapping(mapping_ini.data(), n, range, stats.data(), &enc); |
| enc.PutRange(31, 0, 263, "verif-bit"); |
| EXPECT_WP2_OK(enc.AssembleToBitstream()); |
| // Read back the mapping. |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| EXPECT_TRUE(mapping_res.resize(n)); |
| EXPECT_WP2_OK(LoadMapping(&dec, n, range, mapping_res.data())); |
| for (uint32_t i = 0; i < n; ++i) { |
| EXPECT_EQ(mapping_ini[i], mapping_res[i]); |
| } |
| EXPECT_EQ(dec.ReadRange(0, 263, "verif-bit"), 31); |
| } |
| } |
| |
| INSTANTIATE_TEST_SUITE_P(TestANSMappingInstantiation, TestANSMapping, |
| testing::Range(0u, 2u)); |
| |
| //------------------------------------------------------------------------------ |
| |
| TEST_F(TestANS, DictionarySymbolCost) { |
| ANSDictionary dict; |
| EXPECT_NE(dict.Init(0), WP2_STATUS_OK); |
| EXPECT_WP2_OK(dict.Init(/*max_symbol=*/4)); |
| |
| EXPECT_EQ(dict.SymbolCost(0), 0); |
| EXPECT_EQ(dict.SymbolCost(1), 0); |
| EXPECT_EQ(dict.SymbolCost(2), 0); |
| EXPECT_EQ(dict.SymbolCost(3), 0); |
| |
| dict.RecordSymbol(1, 2); |
| |
| EXPECT_EQ(dict.SymbolCost(0), 0); |
| EXPECT_EQ(dict.SymbolCost(1), 0); |
| EXPECT_EQ(dict.SymbolCost(2), 0); |
| EXPECT_EQ(dict.SymbolCost(3), 0); |
| |
| dict.RecordSymbol(0, 2); |
| |
| EXPECT_EQ(dict.SymbolCost(0), 1); |
| EXPECT_EQ(dict.SymbolCost(1), 1); |
| EXPECT_EQ(dict.SymbolCost(2), 0); |
| EXPECT_EQ(dict.SymbolCost(3), 0); |
| |
| dict.RecordSymbol(3); |
| |
| EXPECT_NEAR(dict.SymbolCost(0), 1.32, 0.01); |
| EXPECT_NEAR(dict.SymbolCost(1), 1.32, 0.01); |
| EXPECT_EQ(dict.SymbolCost(2), 0); |
| EXPECT_NEAR(dict.SymbolCost(3), 2.32, 0.01); |
| |
| // Test max_symbol. |
| EXPECT_EQ(dict.SymbolCost(0, /*max_symbol=*/2), 1); |
| EXPECT_EQ(dict.SymbolCost(1, /*max_symbol=*/2), 1); |
| EXPECT_EQ(dict.SymbolCost(2, /*max_symbol=*/2), 0); |
| |
| EXPECT_EQ(dict.SymbolCost(0, /*max_symbol=*/1), 1); |
| EXPECT_EQ(dict.SymbolCost(1, /*max_symbol=*/1), 1); |
| |
| EXPECT_EQ(dict.SymbolCost(0, /*max_symbol=*/0), 0); |
| |
| // Test quantized dictionary. |
| dict.ResetCounts(); |
| dict.RecordSymbol(0, 1); |
| dict.RecordSymbol(2, 20); |
| for (uint32_t i = 0; i < 2; ++i) { |
| Vector_u32 quantized_counts; |
| ASSERT_TRUE(quantized_counts.resize(4)); |
| quantized_counts[0] = 2; |
| quantized_counts[1] = 0; |
| quantized_counts[2] = 2; |
| quantized_counts[3] = 0; |
| EXPECT_WP2_OK(dict.SetQuantizedCounts(quantized_counts.data())); |
| EXPECT_EQ(dict.SymbolCost(0), 1); |
| EXPECT_EQ(dict.SymbolCost(1), 0); |
| EXPECT_EQ(dict.SymbolCost(2), 1); |
| EXPECT_EQ(dict.SymbolCost(3), 0); |
| // Even if no symbol has been recorded, the cost needs to be the one from |
| // the quantized counts. |
| dict.ResetCounts(); |
| } |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| TEST_F(TestANS, LargeRange) { |
| // List of <value, min, max> |
| // Some edge cases with range <= kANSMaxRange. |
| std::vector<std::tuple<uint32_t, uint32_t, uint32_t>> test_cases = { |
| {100, 0, kANSMaxRange - 3}, |
| {100, 0, kANSMaxRange - 2}, |
| {100, 0, kANSMaxRange - 1}, |
| {kANSMaxRange - 3, 0, kANSMaxRange - 3}, |
| {kANSMaxRange - 2, 0, kANSMaxRange - 2}, |
| {kANSMaxRange - 1, 0, kANSMaxRange - 1}, |
| }; |
| // More random test cases. |
| UniformIntDistribution random(kBaseSeed); |
| for (uint32_t i = 0; i < 1; ++i) { |
| const uint32_t max = random.Get(1u, kANSMaxRange * 10); |
| const uint32_t min = random.Get(0u, max - 1); |
| const uint32_t value = random.Get(min, max); |
| test_cases.push_back(std::make_tuple(value, min, max)); |
| } |
| |
| // Write. |
| ANSEnc enc; |
| for (const auto& t : test_cases) { |
| const uint32_t value = std::get<0>(t); |
| const uint32_t min = std::get<1>(t); |
| const uint32_t max = std::get<2>(t); |
| SCOPED_TRACE(SPrintf("value: %d min: %d max: %d", value, min, max)); |
| EXPECT_EQ(value, PutLargeRange(value, min, max, &enc, "label")); |
| |
| // Check cost. |
| ANSEncCounter counter; |
| PutLargeRange(value, min, max, &counter, "label"); |
| const uint32_t range = max - min + 1; |
| EXPECT_NEAR(counter.GetCost(), std::log2(range), 0.01); |
| } |
| |
| // Read. |
| ASSERT_WP2_OK(enc.AssembleToBitstream(true)); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| for (const auto& t : test_cases) { |
| const uint32_t value = std::get<0>(t); |
| const uint32_t min = std::get<1>(t); |
| const uint32_t max = std::get<2>(t); |
| SCOPED_TRACE(SPrintf("value: %d min: %d max: %d", value, min, max)); |
| EXPECT_EQ(value, ReadLargeRange(min, max, &dec, "label")); |
| } |
| } |
| |
| } // namespace |
| } // namespace WP2 |