| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // Test SymbolWriter/SymbolReader. |
| |
| #include <algorithm> |
| #include <limits> |
| #include <memory> |
| |
| #include "examples/example_utils.h" |
| #include "include/helpers.h" |
| #include "src/common/symbols.h" |
| #include "src/dec/symbols_dec.h" |
| #include "src/enc/symbols_enc.h" |
| #include "src/utils/random.h" |
| #include "src/utils/data_source.h" |
| |
| namespace WP2 { |
| namespace { |
| |
| //------------------------------------------------------------------------------ |
| |
| using StorageMethod = SymbolsInfo::StorageMethod; |
| |
| // Expose SymbolsInfo::Init() for test purposes. |
| class SymbolsInfoTest : public SymbolsInfo { |
| public: |
| using SymbolsInfo::SetInfo; |
| |
| bool operator==(const SymbolsInfoTest& info) const { |
| if (Size() != info.Size()) return false; |
| for (uint32_t sym = 0; sym < Size(); ++sym) { |
| if (!Equal(GetSymbolInfo(sym), info.GetSymbolInfo(sym))) return false; |
| for (uint32_t cluster = 0; cluster < NumClusters(sym); ++cluster) { |
| if (!Equal(GetClusterInfo(sym, cluster), |
| info.GetClusterInfo(sym, cluster))) { |
| return false; |
| } |
| } |
| } |
| return true; |
| } |
| |
| protected: |
| bool Equal(const SymbolInfo& a, const SymbolInfo& b) const { |
| return a.num_clusters == b.num_clusters && |
| a.storage_method == b.storage_method && |
| a.min == b.min && a.max == b.max; |
| } |
| |
| bool Equal(const ClusterInfo& a, const ClusterInfo& b) const { |
| return (a.p0 == b.p0 && a.p1 == b.p1 && |
| a.min == b.min && a.max == b.max && |
| a.cdf == b.cdf); |
| } |
| }; |
| |
| //------------------------------------------------------------------------------ |
| |
| // Initializes 'num_symbol_types' symbols of type auto. |
| void SetupInfo(uint32_t num_symbol_types, int32_t min_symbol, |
| int32_t max_symbol, uint32_t num_clusters, |
| UniformIntDistribution* random, SymbolsInfoTest* info) { |
| for (uint32_t symbol = 0; symbol < num_symbol_types; ++symbol) { |
| const int32_t min = random->Get(min_symbol, max_symbol); |
| const int32_t max = random->Get(min, max_symbol); |
| info->SetInfo(symbol, min, max, num_clusters, StorageMethod::kAuto); |
| } |
| } |
| |
| class DataGenerator { |
| public: |
| struct Symbol { |
| uint32_t s; |
| int32_t value; |
| uint32_t cluster; |
| bool can_be_negative; |
| }; |
| // Enum to hint at the distribution of a symbol. |
| enum class kHint { |
| kUniform, // uniform distribution |
| kDict, // random distribution |
| kPrefixCode, // exponential distribution |
| kTrivial // unique value |
| }; |
| DataGenerator(const SymbolsInfoTest& symbols_info, |
| UniformIntDistribution* random) |
| : symbols_info_(&symbols_info), random_(random) {} |
| |
| // Adds data consisting of 'num_symbols', for the given symbol 'sym', |
| // in a random range chosen in [0, 'max_range']. |
| void AddData(uint32_t sym, uint32_t num_symbols, kHint distribution, |
| SymbolRecorder* const recorder, std::vector<Symbol>* symbols) { |
| ANSEncNoop enc; |
| const uint32_t range = symbols_info_->GetMaxRange(sym); |
| const uint32_t num_symbol_types = symbols_info_->Size(); |
| counts_.resize(num_symbol_types); |
| const uint32_t num_clusters = symbols_info_->NumClusters(sym); |
| counts_[sym].resize(num_clusters); |
| for (uint32_t c = 0; c < num_clusters; ++c) { |
| counts_[sym][c].resize(range, 0); |
| } |
| |
| // Create some biased statistics by repeating values a certain number of |
| // times and then uniformly picking from the vector. |
| std::vector<uint32_t> probas; |
| int32_t trivial_value = std::numeric_limits<int32_t>::max(); |
| switch (distribution) { |
| case kHint::kDict: |
| probas.reserve(num_symbol_types); |
| for (uint32_t r = 0; r < range; ++r) { |
| const uint32_t tmp = random_->Get(0u, 1000u); |
| for (uint32_t j = 0; j <= tmp; ++j) probas.push_back(r); |
| } |
| break; |
| case kHint::kPrefixCode: |
| probas.reserve(num_symbol_types); |
| for (uint32_t r = 0; r < range; ++r) { |
| const float rand_max = 1000.f * std::exp(-(float)r); |
| const uint32_t tmp = |
| random_->Get(0u, std::max((uint32_t)rand_max, 1u)); |
| for (uint32_t j = 0; j <= tmp; ++j) probas.push_back(r); |
| } |
| break; |
| case kHint::kTrivial: |
| trivial_value = random_->Get(symbols_info_->Min(sym, /*cluster=*/0), |
| symbols_info_->Max(sym, /*cluster=*/0)); |
| break; |
| default: |
| break; |
| } |
| |
| // Create some stream using the biased statistics. |
| const bool can_be_negative = (symbols_info_->Min(sym, /*cluster=*/0) < 0); |
| bool has_at_least_one_negative = false; |
| for (uint32_t i = 0; i < num_symbols; ++i) { |
| const uint32_t cluster = random_->Get(0u, num_clusters - 1); |
| int32_t value = std::numeric_limits<int32_t>::max(); |
| switch (distribution) { |
| case kHint::kDict: |
| case kHint::kPrefixCode: |
| // Other symbols have biased data for which a dictionary will be more |
| // efficient. |
| value = probas[random_->Get(0, (int)probas.size() - 1)]; |
| break; |
| case kHint::kUniform: |
| // Getting uniform probabilities for the symbol will force it to be |
| // stored as a range by the SymbolWriter. |
| value = random_->Get(0u, range - 1); |
| break; |
| case kHint::kTrivial: |
| value = trivial_value; |
| break; |
| } |
| Symbol s; |
| s.s = sym; |
| s.value = value + symbols_info_->Min(sym, cluster); |
| has_at_least_one_negative |= (s.value < 0); |
| s.can_be_negative = can_be_negative; |
| s.cluster = cluster; |
| symbols->push_back(s); |
| ++(counts_)[sym][cluster][value]; |
| recorder->Process(sym, cluster, s.value, "label", &enc); |
| } |
| // Make sure there is at least one negative value if asked. |
| EXPECT_TRUE(!can_be_negative || has_at_least_one_negative); |
| } |
| |
| void WriteHeader(const SymbolsInfoTest& info, const SymbolRecorder& recorder, |
| uint32_t* max_nnz, ANSDictionaries* const dicts, |
| ANSEnc* const enc, SymbolWriter* const sw) { |
| ASSERT_WP2_OK(sw->Init(info, /*effort=*/5)); |
| ASSERT_WP2_OK(sw->Allocate()); |
| // Get the maximum number of non-zero values by aggregating over all |
| // clusters. |
| *max_nnz = 0; |
| for (uint32_t s = 0; s < symbols_info_->Size(); ++s) { |
| if (counts_[s].empty()) continue; |
| for (uint32_t c = 0; c < symbols_info_->NumClusters(s); ++c) { |
| *max_nnz = std::max(*max_nnz, std::accumulate(counts_[s][c].begin(), |
| counts_[s][c].end(), 0u)); |
| } |
| } |
| // Write the headers. |
| for (uint32_t s = 0; s < symbols_info_->Size(); ++s) { |
| for (uint32_t c = 0; c < symbols_info_->NumClusters(s); ++c) { |
| ASSERT_WP2_OK( |
| sw->WriteHeader(s, c, *max_nnz, recorder, "counts", enc, dicts)); |
| } |
| } |
| } |
| |
| void ReadHeader(const SymbolsInfoTest& info, uint32_t max_nnz, |
| ANSDec* const dec, SymbolReader* const sr) { |
| ASSERT_WP2_OK(sr->Init(info, dec)); |
| ASSERT_WP2_OK(sr->Allocate()); |
| for (uint32_t s = 0; s < symbols_info_->Size(); ++s) { |
| for (uint32_t c = 0; c < symbols_info_->NumClusters(s); ++c) { |
| ASSERT_WP2_OK(sr->ReadHeader(s, c, max_nnz, "counts")); |
| } |
| } |
| } |
| |
| private: |
| const SymbolsInfoTest* const symbols_info_; |
| UniformIntDistribution* random_; |
| // Per symbol, per cluster, per value. |
| std::vector<std::vector<std::vector<uint32_t>>> counts_; |
| }; |
| |
| constexpr std::array<DataGenerator::kHint, 4> kDistributions = { |
| DataGenerator::kHint::kUniform, DataGenerator::kHint::kDict, |
| DataGenerator::kHint::kPrefixCode, DataGenerator::kHint::kTrivial}; |
| |
| //------------------------------------------------------------------------------ |
| |
| TEST(SymbolsInfo, BasicTest) { |
| SymbolsInfoTest info; |
| info.SetInfo(/*sym=*/0, /*min=*/-9, /*max=*/9, /*num_clusters=*/2, |
| StorageMethod::kAuto); |
| EXPECT_WP2_OK( |
| info.SetMinMax(/*sym=*/0, /*cluster=*/1, /*min=*/-4, /*max=*/4)); |
| info.SetInfo(/*sym=*/1, /*min=*/0, /*max=*/1, /*num_clusters=*/3, |
| StorageMethod::kAdaptiveBit); |
| EXPECT_WP2_OK(info.SetStartingProba(/*sym=*/1, /*cluster=*/0, 3, 12)); |
| EXPECT_WP2_OK(info.SetStartingProba(/*sym=*/1, /*cluster=*/2, 8, 2)); |
| info.SetInfo(/*sym=*/3, /*min=*/0, /*max=*/3, |
| /*num_clusters=*/3, StorageMethod::kAuto); |
| EXPECT_WP2_OK( |
| info.SetMinMax(/*sym=*/3, /*cluster=*/2, /*min=*/0, /*max=*/19)); |
| |
| // Number of symbol. Even though we didn't set any info for symbol '2', it is |
| // assumed to exist. |
| EXPECT_EQ(info.Size(), 4u); |
| EXPECT_EQ(info.GetMaxRange(), 19u); |
| EXPECT_EQ(info.MaxRangeSum(), 41u); |
| EXPECT_EQ(info.RangeSum(), 62u); |
| |
| EXPECT_EQ(info.Range(0, /*cluster=*/0), 19u); |
| EXPECT_EQ(info.Range(0, /*cluster=*/1), 9u); |
| EXPECT_EQ(info.GetMaxRange(0), 19u); |
| EXPECT_EQ(info.Range(1, /*cluster=*/0), 2u); |
| EXPECT_EQ(info.Range(1, /*cluster=*/1), 2u); |
| EXPECT_EQ(info.Range(1, /*cluster=*/2), 2u); |
| EXPECT_EQ(info.GetMaxRange(1), 2u); |
| EXPECT_EQ(info.GetMaxRange(2), 0u); |
| EXPECT_EQ(info.Range(3, /*cluster=*/0), 4u); |
| EXPECT_EQ(info.Range(3, /*cluster=*/1), 4u); |
| EXPECT_EQ(info.Range(3, /*cluster=*/2), 20u); |
| EXPECT_EQ(info.GetMaxRange(3), 20u); |
| |
| EXPECT_EQ(info.NumClusters(0), 2u); |
| EXPECT_EQ(info.NumClusters(1), 3u); |
| EXPECT_EQ(info.NumClusters(2), 0u); |
| EXPECT_EQ(info.NumClusters(3), 3u); |
| |
| EXPECT_EQ(info.Method(0), StorageMethod::kAuto); |
| EXPECT_EQ(info.Method(1), StorageMethod::kAdaptiveBit); |
| EXPECT_EQ(info.Method(2), StorageMethod::kAuto); |
| EXPECT_EQ(info.Method(3), StorageMethod::kAuto); |
| |
| SymbolsInfoTest copy; |
| ASSERT_WP2_OK(copy.CopyFrom(info)); |
| EXPECT_TRUE(copy == info); |
| copy.SetInfo(/*sym=*/2, /*min=*/0, /*max=*/4, /*num_clusters=*/1, |
| StorageMethod::kAuto); |
| EXPECT_FALSE(copy == info); |
| EXPECT_EQ(copy.Size(), 4u); |
| EXPECT_EQ(copy.GetMaxRange(), 19u); |
| EXPECT_EQ(copy.MaxRangeSum(), 46u); |
| EXPECT_EQ(copy.RangeSum(), 67u); |
| |
| EXPECT_EQ(info.StartingProbaP0(/*sym=*/1, /*cluster=*/0), 3u); |
| EXPECT_EQ(info.StartingProbaP1(/*sym=*/1, /*cluster=*/0), 12u); |
| EXPECT_EQ(info.StartingProbaP0(/*sym=*/1, /*cluster=*/1), 1u); |
| EXPECT_EQ(info.StartingProbaP1(/*sym=*/1, /*cluster=*/1), 1u); |
| EXPECT_EQ(info.StartingProbaP0(/*sym=*/1, /*cluster=*/2), 8u); |
| EXPECT_EQ(info.StartingProbaP1(/*sym=*/1, /*cluster=*/2), 2u); |
| } |
| |
| TEST(SymbolsInfo, LossessSymbolsInfo) { |
| WP2L::LosslessSymbolsInfo info; |
| info.Init(/*num_pixels=*/1000, /*has_alpha=*/true, WP2_Argb_32); |
| info.SetNumClusters(3); |
| info.SetCacheRange(1 << 2); |
| EXPECT_EQ(info.Range(WP2L::kSymbolA, /*cluster=*/0), (uint32_t)(1 << 8)); |
| EXPECT_EQ(info.NumClusters(WP2L::kSymbolA), 3u); |
| |
| EXPECT_EQ(info.Range(WP2L::kSymbolCache, /*cluster=*/0), (uint32_t)(1 << 2)); |
| EXPECT_EQ(info.NumClusters(WP2L::kSymbolCache), 3u); |
| |
| info.SetCacheRange(1 << 4); |
| EXPECT_EQ(info.Range(WP2L::kSymbolCache, /*cluster=*/0), (uint32_t)(1 << 4)); |
| EXPECT_EQ(info.NumClusters(WP2L::kSymbolCache), 3u); |
| |
| info.SetNumClusters(2); |
| EXPECT_EQ(info.NumClusters(WP2L::kSymbolA), 2u); |
| EXPECT_EQ(info.NumClusters(WP2L::kSymbolCache), 2u); |
| } |
| |
| // Helper function for the following two tests. |
| void VerifyProcessRead(const SymbolsInfoTest& info, uint32_t num_symbols, |
| uint32_t num_symbol_types, |
| DataGenerator* const generator) { |
| std::vector<DataGenerator::Symbol> symbols; |
| SymbolRecorder recorder; |
| ASSERT_WP2_OK(recorder.Allocate(info, num_symbols)); |
| |
| for (uint32_t s = 0; s < num_symbol_types; ++s) { |
| generator->AddData( |
| s, num_symbols, |
| kDistributions[s * kDistributions.size() / num_symbol_types], &recorder, |
| &symbols); |
| } |
| Shuffle(symbols.begin(), symbols.end(), 0); |
| |
| // Create the symbol writer. |
| std::unique_ptr<WP2::SymbolWriter> sw(new (WP2Allocable::nothrow) |
| WP2::SymbolWriter); |
| ASSERT_TRUE(sw != nullptr); |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| uint32_t max_nnz = 0; |
| generator->WriteHeader(info, recorder, &max_nnz, &dicts, &enc, sw.get()); |
| const float header_cost = enc.GetCost(); |
| |
| // Write the symbols. |
| float enc_cost = 0.f; |
| for (const DataGenerator::Symbol& sym : symbols) { |
| sw->ProcessWithCost(sym.s, sym.cluster, sym.value, "sym", &enc, &enc_cost); |
| } |
| EXPECT_WP2_OK(enc.AssembleToBitstream(/*clear_tokens=*/false)); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| EXPECT_EQ(bits.size(), enc.GetBitstreamSize()); |
| |
| // Decode the symbol reader statistics. |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| SymbolReader sr; |
| generator->ReadHeader(info, max_nnz, &dec, &sr); |
| |
| // Verify the symbols do match the original data. |
| float dec_cost = enc_cost; |
| float* cost_ptr = nullptr; |
| #if defined(WP2_BITTRACE) |
| dec_cost = 0.; |
| cost_ptr = &dec_cost; |
| #endif |
| for (const DataGenerator::Symbol& sym : symbols) { |
| ASSERT_EQ(sr.Read(sym.s, sym.cluster, "sym", cost_ptr), sym.value); |
| } |
| const float symbol_cost = enc.GetCost() - header_cost; |
| EXPECT_NEAR(symbol_cost, enc_cost, 0.2f); |
| EXPECT_NEAR(symbol_cost, dec_cost, 0.2f); |
| ASSERT_WP2_OK(dec.GetStatus()); |
| } |
| |
| // Test basic functionality of the SymbolWriter/SymbolReader. |
| TEST(SymbolsTest, Basic) { |
| static constexpr uint32_t kNumSymbol = 1000u; |
| static constexpr uint32_t kNumSymbolTypes = 5u; |
| static constexpr int32_t kMinSymbol = -10; |
| static constexpr int32_t kMaxSymbol = 100; |
| static constexpr uint32_t kNumClusters = 10u; |
| UniformIntDistribution random(/*seed=*/0); |
| |
| static_assert(kNumSymbolTypes <= kSymbolNumMax, "Need less symbols."); |
| |
| // Generate the data. |
| SymbolsInfoTest info; |
| SetupInfo(kNumSymbolTypes, kMinSymbol, kMaxSymbol, kNumClusters, &random, |
| &info); |
| DataGenerator generator(info, &random); |
| |
| VerifyProcessRead(info, kNumSymbol, kNumSymbolTypes, &generator); |
| } |
| |
| // Test the SymbolWriter/SymbolReader when CDF is given. |
| TEST(SymbolsTest, CDF) { |
| ANSInit(); |
| static constexpr uint32_t kNumSymbol = 10000u; |
| static constexpr uint32_t kNumSymbolTypes = 2u; |
| static constexpr int32_t kMinSymbol = 0; |
| static constexpr int32_t kMaxSymbol = APROBA_MAX_SYMBOL; |
| static constexpr uint32_t kNumClusters = 10u; |
| static constexpr uint32_t kMaxCount = 1000u; |
| UniformIntDistribution random(/*seed=*/0); |
| |
| static_assert(kNumSymbolTypes <= kSymbolNumMax, "Need less symbols."); |
| |
| // Generate the data. |
| SymbolsInfoTest info; |
| SetupInfo(kNumSymbolTypes, kMinSymbol, kMaxSymbol, kNumClusters, &random, |
| &info); |
| uint16_t cdfs[kNumSymbolTypes][APROBA_MAX_SYMBOL]; |
| for (uint32_t i = 0; i < kNumSymbolTypes; ++i) { |
| cdfs[i][0] = 0; |
| for (uint32_t j = 1; j < APROBA_MAX_SYMBOL; ++j) { |
| const uint32_t c = random.Get(1u, kMaxCount); |
| cdfs[i][j] = cdfs[i][j - 1] + c; |
| } |
| // Normalize. |
| const uint32_t sum = |
| cdfs[i][APROBA_MAX_SYMBOL - 1] + random.Get(1u, kMaxCount); |
| for (uint32_t j = 1; j < APROBA_MAX_SYMBOL; ++j) { |
| cdfs[i][j] = cdfs[i][j] * ANS_MAX_SYMBOLS / sum; |
| } |
| } |
| for (uint32_t i = 0; i < kNumSymbolTypes; ++i) { |
| info.SetInfo(i, /*min=*/0, /*max=*/info.Range(i, /*cluster=*/0) - 1, |
| info.NumClusters(i), StorageMethod::kAdaptiveSym); |
| EXPECT_WP2_OK(info.SetInitialCDF(i, /*cluster=*/0, |
| cdfs[i], ANS_MAX_SYMBOLS)); |
| } |
| DataGenerator generator(info, &random); |
| |
| VerifyProcessRead(info, kNumSymbol, kNumSymbolTypes, &generator); |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| // Test the SymbolWriter/SymbolReader for uniform probability. |
| TEST(SymbolsTest, Uniform) { |
| static constexpr uint32_t kNumSymbol = 10000u; |
| static constexpr int32_t kMinSymbol = -23; |
| static constexpr int32_t kMaxSymbol = 1000; |
| UniformIntDistribution random(/*seed=*/0); |
| |
| // Generate the data. |
| SymbolsInfoTest info; |
| SetupInfo(/*num_symbol_types=*/1, kMinSymbol, kMaxSymbol, /*num_clusters=*/1, |
| &random, &info); |
| DataGenerator generator(info, &random); |
| |
| SymbolRecorder recorder; |
| ASSERT_WP2_OK(recorder.Allocate(info, kNumSymbol)); |
| |
| std::vector<DataGenerator::Symbol> symbols; |
| generator.AddData(0, kNumSymbol, DataGenerator::kHint::kUniform, &recorder, |
| &symbols); |
| |
| // Create the symbol writer. |
| std::unique_ptr<WP2::SymbolWriter> sw(new (WP2Allocable::nothrow) |
| WP2::SymbolWriter); |
| ASSERT_TRUE(sw != nullptr); |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| uint32_t max_nnz = 0; |
| generator.WriteHeader(info, recorder, &max_nnz, &dicts, &enc, sw.get()); |
| |
| // Write the symbols. |
| int32_t min = symbols[0].value, max = symbols[0].value; |
| for (const DataGenerator::Symbol& sym : symbols) { |
| sw->Process(sym.s, sym.cluster, sym.value, "sym", &enc); |
| min = std::min(min, sym.value); |
| max = std::max(max, sym.value); |
| } |
| const uint32_t range = max - min + 1; |
| EXPECT_GT(range, 0u); |
| EXPECT_WP2_OK(enc.AssembleToBitstream()); |
| const uint32_t size = enc.GetBitstreamSize(); |
| |
| // With uniform data, the symbol writer should choose the range storage |
| // method (there is a little overhead for storing the range). |
| ASSERT_NEAR(symbols.size() * std::log2(range) / 8, size, size * 0.0022); |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| // Test the fact that the SymbolWriter/SymbolReader can work when restricting |
| // the range of values. |
| TEST(SymbolsTest, MaxValue) { |
| const uint32_t kNumSymbol = 1000u; |
| const uint32_t kNumSymbolTypes = 5u; |
| const int32_t kMinSymbol = -5; |
| const int32_t kMaxSymbol = 100; |
| const uint32_t kNumClusters = 1u; |
| ASSERT_EQ(kNumClusters, 1u); |
| UniformIntDistribution random(/*seed=*/0); |
| |
| static_assert(kNumSymbolTypes < kSymbolNumMax, "Need less symbols."); |
| |
| // Generate the data. |
| SymbolsInfoTest info; |
| SetupInfo(kNumSymbolTypes, kMinSymbol, kMaxSymbol, kNumClusters, &random, |
| &info); |
| DataGenerator generator(info, &random); |
| std::vector<DataGenerator::Symbol> symbols; |
| |
| SymbolRecorder recorder; |
| ASSERT_WP2_OK(recorder.Allocate(info, kNumSymbol)); |
| |
| for (uint32_t s = 0; s < kNumSymbolTypes; ++s) { |
| generator.AddData( |
| s, kNumSymbol, |
| kDistributions[s * kDistributions.size() / kNumSymbolTypes], &recorder, |
| &symbols); |
| } |
| Shuffle(symbols.begin(), symbols.end(), 0); |
| |
| // Create the symbol writer. |
| std::unique_ptr<WP2::SymbolWriter> sw(new (WP2Allocable::nothrow) |
| WP2::SymbolWriter); |
| ASSERT_TRUE(sw != nullptr); |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| uint32_t max_nnz = 0; |
| generator.WriteHeader(info, recorder, &max_nnz, &dicts, &enc, sw.get()); |
| |
| // Write the data. We artificially cap to multiples of the value. |
| for (float mul : {1.f, 1.2f, 2.f, 100.f}) { |
| for (const DataGenerator::Symbol& sym : symbols) { |
| sw->Process(sym.s, sym.cluster, sym.value, /*max_value=*/ |
| std::abs(sym.value) * mul, "sym", &enc); |
| } |
| } |
| EXPECT_WP2_OK(enc.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| const uint32_t size = bits.size(); |
| ASSERT_EQ(size, enc.GetBitstreamSize()); |
| const uint8_t* buf = bits.data(); |
| |
| // Decode the symbol reader statistics. |
| ExternalDataSource data_source(buf, size); |
| ANSDec dec(&data_source); |
| SymbolReader sr; |
| generator.ReadHeader(info, max_nnz, &dec, &sr); |
| |
| // Verify the symbols do match the original data. |
| for (float mul : {1.f, 1.2f, 2.f, 100.f}) { |
| for (const DataGenerator::Symbol& sym : symbols) { |
| int32_t value; |
| ASSERT_WP2_OK(sr.ReadWithMax(sym.s, sym.cluster, |
| /*max_value=*/std::abs(sym.value) * mul, |
| "sym", &value, nullptr)); |
| ASSERT_EQ(std::abs(value), std::abs(sym.value)) << "mul " << mul; |
| } |
| } |
| ASSERT_WP2_OK(dec.GetStatus()); |
| } |
| |
| // Test the SymbolWriter/SymbolReader for StorageMethod::kAdaptive. |
| TEST(SymbolsTest, StorageMethodAdaptive) { |
| ANSInit(); |
| static constexpr uint32_t kNumSymbol = 1000u; |
| static constexpr uint32_t kMaxSymbol = 10u; |
| UniformIntDistribution random(/*seed=*/0); |
| |
| // Generate the data. |
| SymbolsInfoTest info; |
| info.SetInfo(/*sym=*/0, /*min=*/0, /*max=*/kMaxSymbol, /*num_clusters=*/1, |
| StorageMethod::kAdaptiveSym); |
| DataGenerator generator(info, &random); |
| |
| SymbolRecorder recorder; |
| ASSERT_WP2_OK(recorder.Allocate(info, kNumSymbol)); |
| |
| std::vector<DataGenerator::Symbol> symbols; |
| generator.AddData(0, kNumSymbol, DataGenerator::kHint::kUniform, &recorder, |
| &symbols); |
| |
| // Create the symbol writer. |
| std::unique_ptr<WP2::SymbolWriter> sw(new (WP2Allocable::nothrow) |
| WP2::SymbolWriter); |
| ASSERT_TRUE(sw != nullptr); |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| uint32_t max_nnz = 0; |
| generator.WriteHeader(info, recorder, &max_nnz, &dicts, &enc, sw.get()); |
| |
| EXPECT_EQ(enc.GetCost(), 0); // Header should be free. |
| |
| // Write the symbols. |
| for (const DataGenerator::Symbol& sym : symbols) { |
| sw->Process(sym.s, sym.cluster, sym.value, "sym", &enc); |
| } |
| EXPECT_WP2_OK(enc.AssembleToBitstream()); |
| uint32_t size = enc.GetBitstreamSize(); |
| |
| // Since the data is uniform, the adaptive bit shouldn't win much. |
| EXPECT_NEAR(symbols.size() * std::log2(kMaxSymbol + 1) / 8, size, size * 0.1); |
| |
| // Decode the symbol reader statistics. |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| SymbolReader sr; |
| generator.ReadHeader(info, max_nnz, &dec, &sr); |
| |
| // Verify the symbols do match the original data. |
| for (const DataGenerator::Symbol& sym : symbols) { |
| ASSERT_EQ(sr.Read(sym.s, sym.cluster, "sym", nullptr), sym.value); |
| } |
| ASSERT_WP2_OK(dec.GetStatus()); |
| |
| // Try again with very predictable data. |
| enc.Reset(); |
| for (uint32_t i = 0; i < kNumSymbol; ++i) { |
| // Long string of zeros followed by a long string of ones. |
| const uint32_t value = (i < kNumSymbol / 2) ? 0 : 1; |
| sw->Process(/*sym=*/0, /*cluster=*/0, value, "sym", &enc); |
| } |
| EXPECT_WP2_OK(enc.AssembleToBitstream()); |
| size = enc.GetBitstreamSize(); |
| // Since the data is very correlated it should be pretty cheap. |
| EXPECT_LT(size, kNumSymbol * std::log2(kMaxSymbol + 1) * 0.015); |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| // Test the SymbolWriter/SymbolReader for StorageMethod::kAdaptiveWithAutoSpeed. |
| TEST(SymbolsTest, StorageMethodAdaptiveWithAutoSpeed) { |
| ANSInit(); |
| static constexpr uint32_t kNumSymbol = 100u; |
| static constexpr uint32_t kMaxSymbol = 10u; |
| |
| // Generate the data. |
| SymbolsInfoTest info; |
| info.SetInfo(/*sym=*/0, /*min=*/0, /*max=*/kMaxSymbol, /*num_clusters=*/1, |
| StorageMethod::kAdaptiveWithAutoSpeed); |
| |
| SymbolRecorder recorder; |
| ASSERT_WP2_OK(recorder.Allocate(info, kNumSymbol)); |
| |
| // Record symbols. |
| ANSEncNoop noop; |
| std::vector<DataGenerator::Symbol> symbols; |
| for (uint32_t i = 0; i < kNumSymbol; ++i) { |
| // Long string of zeros followed by a long string of ones. |
| const int32_t value = (i < kNumSymbol / 2) ? 0 : 1; |
| symbols.push_back( |
| {/*s=*/0, value, /*cluster=*/0, /*can_be_negative=*/true}); |
| recorder.Process(/*sym=*/0, /*cluster=*/0, value, "sym", &noop); |
| } |
| |
| // Create the symbol writer. |
| std::unique_ptr<WP2::SymbolWriter> sw(new (WP2Allocable::nothrow) |
| WP2::SymbolWriter); |
| ASSERT_TRUE(sw != nullptr); |
| ASSERT_WP2_OK(sw->Init(info, /*effort=*/5)); |
| ASSERT_WP2_OK(sw->Allocate()); |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| uint32_t max_nnz = kNumSymbol; |
| ASSERT_WP2_OK( |
| sw->WriteHeader( /*sym=*/0, max_nnz, recorder, "header", &enc, &dicts)); |
| |
| for (const DataGenerator::Symbol& sym : symbols) { |
| sw->Process(sym.s, sym.cluster, sym.value, "sym", &enc); |
| } |
| EXPECT_WP2_OK(enc.AssembleToBitstream()); |
| uint32_t size = enc.GetBitstreamSize(); |
| |
| // Since the data is very correlated it should be pretty cheap. |
| EXPECT_LT(size, kNumSymbol * std::log2(kMaxSymbol + 1) * 0.02); |
| |
| // Decode the symbol reader statistics. |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| SymbolReader sr; |
| ASSERT_WP2_OK(sr.Init(info, &dec)); |
| ASSERT_WP2_OK(sr.Allocate()); |
| ASSERT_WP2_OK(sr.ReadHeader(/*sym=*/0, max_nnz, "header")); |
| |
| // Verify the symbols do match the original data. |
| for (const DataGenerator::Symbol& sym : symbols) { |
| ASSERT_EQ(sr.Read(sym.s, sym.cluster, "sym", nullptr), sym.value); |
| } |
| ASSERT_WP2_OK(dec.GetStatus()); |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| enum SymbolCounterType { kFast, kUpdating }; |
| |
| std::unique_ptr<SymbolCounter> CreateSymbolCounter( |
| SymbolCounterType type, const SymbolRecorder& recorder) { |
| return std::unique_ptr<SymbolCounter>( |
| (type == kUpdating) |
| ? new (WP2Allocable::nothrow) UpdatingSymbolCounter(&recorder) |
| : new (WP2Allocable::nothrow) SymbolCounter(&recorder)); |
| } |
| |
| class TestSymbolCounter : public testing::TestWithParam<SymbolCounterType> {}; |
| |
| // Test that UpdatingSymbolCounter gives accurate measurements. |
| TEST_P(TestSymbolCounter, SymbolCounter) { |
| const SymbolCounterType type = GetParam(); |
| |
| WP2MathInit(); |
| ANSInit(); |
| static constexpr uint32_t kNumSymbol = 10000u; |
| static constexpr int32_t kMinSymbol = -5; |
| static constexpr int32_t kMaxSymbol = 100; |
| static constexpr uint32_t kNumClusters = 10u; |
| UniformIntDistribution random(/*seed=*/0); |
| |
| // Generate the data. |
| SymbolsInfoTest info; |
| // Three 'auto' symbols. |
| SetupInfo(/*num_symbol_types=*/3, kMinSymbol, kMaxSymbol, kNumClusters, |
| &random, &info); |
| |
| // Only the UpdatingSymbolCounter is accurate for adaptive bits/symbols. |
| if (type == kUpdating) { |
| // Adaptive bit. |
| info.SetInfo(/*sym=*/info.Size(), /*min=*/0, /*max=*/1, /*num_clusters=*/4, |
| StorageMethod::kAdaptiveBit); |
| EXPECT_WP2_OK(info.SetStartingProba(/*sym=*/info.Size() - 1, /*cluster=*/2, |
| /*p0=*/3, /*p1=*/7)); |
| EXPECT_WP2_OK(info.SetStartingProba(/*sym=*/info.Size() - 1, /*cluster=*/3, |
| /*p0=*/100, /*p1=*/1)); |
| // Adaptive symbol. |
| info.SetInfo(/*sym=*/info.Size(), /*min=*/0, /*max=*/9, /*num_clusters=*/3, |
| StorageMethod::kAdaptiveSym); |
| // Test that binary symbols are not only kAdaptiveBit but can also be |
| // kAdaptiveSym. |
| info.SetInfo(/*sym=*/info.Size(), /*min=*/0, /*max=*/1, /*num_clusters=*/2, |
| StorageMethod::kAdaptiveSym); |
| } |
| |
| const uint32_t num_symbol_types = info.Size(); |
| DataGenerator generator(info, &random); |
| |
| std::vector<DataGenerator::Symbol> symbols; |
| |
| SymbolRecorder recorder; |
| ASSERT_WP2_OK(recorder.Allocate(info, kNumSymbol)); |
| |
| for (uint32_t s = 0; s < num_symbol_types; ++s) { |
| const DataGenerator::kHint distribution = |
| (info.Method(s) == StorageMethod::kAdaptiveSym) |
| ? DataGenerator::kHint::kDict |
| : DataGenerator::kHint::kUniform; |
| generator.AddData(s, kNumSymbol, distribution, &recorder, &symbols); |
| } |
| Shuffle(symbols.begin(), symbols.end(), 0); |
| |
| // Create the symbol writer and write the header. |
| std::unique_ptr<WP2::SymbolWriter> sw(new (WP2Allocable::nothrow) |
| WP2::SymbolWriter); |
| ASSERT_TRUE(sw != nullptr); |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| uint32_t max_nnz = 0; |
| generator.WriteHeader(info, recorder, &max_nnz, &dicts, &enc, sw.get()); |
| const float header_cost = enc.GetCost(); |
| |
| // Write the symbols. |
| for (const DataGenerator::Symbol& sym : symbols) { |
| sw->Process(sym.s, sym.cluster, sym.value, "sym", &enc); |
| } |
| const float real_cost = enc.GetCost() - header_cost; |
| |
| ASSERT_WP2_OK(recorder.ResetRecord(/*reset_backup=*/true)); |
| ANSEncCounter enc_counter; |
| std::unique_ptr<SymbolCounter> symbol_counter = |
| CreateSymbolCounter(type, recorder); |
| ASSERT_WP2_OK(symbol_counter->Allocate(/*syms=*/{3, 4, 5})); |
| for (const DataGenerator::Symbol& sym : symbols) { |
| symbol_counter->Process(sym.s, sym.cluster, sym.value, "sym", &enc_counter); |
| } |
| |
| // A small inaccuracy is expected because of "auto" symbols, whose exact cost |
| // is not known in advance. |
| EXPECT_NEAR(enc_counter.GetCost(), real_cost, 2); |
| } |
| |
| // Checks that we can perfectly predict the cost of adaptive bits/symbols. |
| TEST_P(TestSymbolCounter, AdaptiveSymbols) { |
| const SymbolCounterType type = GetParam(); |
| |
| WP2MathInit(); |
| ANSInit(); |
| UniformIntDistribution random(/*seed=*/0); |
| |
| // Generate the data. |
| SymbolsInfoTest info; |
| |
| // Adaptive bit. |
| info.SetInfo(/*sym=*/info.Size(), /*min=*/0, /*max=*/1, /*num_clusters=*/4, |
| StorageMethod::kAdaptiveBit); |
| EXPECT_WP2_OK(info.SetStartingProba(/*sym=*/info.Size() - 1, /*cluster=*/2, |
| /*p0=*/3, /*p1=*/7)); |
| EXPECT_WP2_OK(info.SetStartingProba(/*sym=*/info.Size() - 1, /*cluster=*/3, |
| /*p0=*/100, /*p1=*/1)); |
| // Adaptive symbol. |
| info.SetInfo(/*sym=*/info.Size(), /*min=*/0, /*max=*/9, /*num_clusters=*/3, |
| StorageMethod::kAdaptiveSym); |
| // Add starting probas for one of the clusters. |
| constexpr uint16_t kCDF[10] = {0, 2000, 2200, 2300, 4000, |
| 4020, 4600, 5000, 6000, 6500}; |
| constexpr uint16_t kMaxProba = 8000; |
| EXPECT_WP2_OK(info.SetInitialCDF(/*sym=*/info.Size() - 1, /*cluster=*/1, kCDF, |
| kMaxProba)); |
| |
| // Test that binary symbols are not only kAdaptiveBit but can also be |
| // kAdaptiveSym. |
| info.SetInfo(/*sym=*/info.Size(), /*min=*/0, /*max=*/1, /*num_clusters=*/2, |
| StorageMethod::kAdaptiveSym); |
| |
| const uint32_t num_symbol_types = info.Size(); |
| DataGenerator generator(info, &random); |
| |
| std::vector<DataGenerator::Symbol> symbols; |
| |
| // The fast symbol counter does not update adaptive bit/symbols, therefore it |
| // is only accurate for the first instance of each symbol. So we only |
| // generate one value per symbol. |
| const uint32_t num_symbol = (type == kFast) ? 1 : 100u; |
| |
| SymbolRecorder recorder; |
| ASSERT_WP2_OK(recorder.Allocate(info, num_symbol)); |
| |
| for (uint32_t s = 0; s < num_symbol_types; ++s) { |
| generator.AddData(s, num_symbol, DataGenerator::kHint::kDict, &recorder, |
| &symbols); |
| } |
| Shuffle(symbols.begin(), symbols.end(), 0); |
| |
| // Create the symbol writer and write the header. |
| std::unique_ptr<WP2::SymbolWriter> sw(new (WP2Allocable::nothrow) |
| WP2::SymbolWriter); |
| ASSERT_TRUE(sw != nullptr); |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| uint32_t max_nnz = 0; |
| generator.WriteHeader(info, recorder, &max_nnz, &dicts, &enc, sw.get()); |
| const float header_cost = enc.GetCost(); |
| |
| // Write the symbols. |
| for (const DataGenerator::Symbol& sym : symbols) { |
| sw->Process(sym.s, sym.cluster, sym.value, "sym", &enc); |
| } |
| const float real_cost = enc.GetCost() - header_cost; |
| |
| ASSERT_WP2_OK(recorder.ResetRecord(/*reset_backup=*/true)); |
| ANSEncCounter enc_counter; |
| std::unique_ptr<SymbolCounter> symbol_counter = |
| CreateSymbolCounter(type, recorder); |
| ASSERT_WP2_OK(symbol_counter->Allocate(/*syms=*/{0, 1, 2})); |
| |
| for (const DataGenerator::Symbol& sym : symbols) { |
| symbol_counter->Process(sym.s, sym.cluster, sym.value, "sym", &enc_counter); |
| } |
| |
| EXPECT_NEAR(enc_counter.GetCost(), real_cost, 0.01f); |
| } |
| |
| TEST_P(TestSymbolCounter, SymbolCounter_Dictionary) { |
| const SymbolCounterType type = GetParam(); |
| |
| WP2MathInit(); |
| // Generate the data. |
| SymbolsInfoTest info; |
| info.SetInfo(0, /*min=*/0, /*max=*/7, /*num_clusters=*/1, |
| StorageMethod::kAuto); |
| |
| SymbolRecorder recorder; |
| ASSERT_WP2_OK(recorder.Allocate(info, /*num_records=*/0)); |
| |
| ANSEncCounter counter; |
| std::unique_ptr<SymbolCounter> symbol_counter = |
| CreateSymbolCounter(type, recorder); |
| ASSERT_WP2_OK(symbol_counter->Allocate(/*syms=*/{})); |
| symbol_counter->Process(0, 1, "test", &counter); |
| // Before anything was recorded, all values are assumed to have the same cost |
| // which is log2(range) = log2(8) = 3 |
| EXPECT_EQ(counter.GetCost(), 3); |
| counter.Reset(); |
| symbol_counter->Process(0, 3, "test", &counter); |
| EXPECT_EQ(counter.GetCost(), 3); |
| |
| ANSEncNoop noop; |
| std::vector<uint8_t> data(400); |
| for (uint32_t i = 0; i < 400; ++i) data[i] = (i < 100) ? 1 : 3; |
| std::mt19937 gen(0); |
| std::shuffle(data.begin(), data.end(), gen); |
| for (uint8_t i : data) recorder.Process(/*sym=*/0, i, "test", &noop); |
| |
| counter.Reset(); |
| symbol_counter->Process(0, 1, "test", &counter); |
| // After recording, costs are based on the recorded stats. -log2(1/4) = 2 |
| EXPECT_EQ(counter.GetCost(), 2); |
| counter.Reset(); |
| symbol_counter->Process(0, 3, "test", &counter); |
| // -log2(3/4) |
| EXPECT_NEAR(counter.GetCost(), 0.415, 0.001); |
| counter.Reset(); |
| symbol_counter->Process(0, 2, "test", &counter); |
| // Value "2" was never recorded so we assume a cost above -1*log2(1/400). |
| EXPECT_NEAR(counter.GetCost(), 8.643, 0.001); |
| ASSERT_WP2_OK(recorder.MakeBackup()); |
| float costs[2]; |
| for (uint32_t reset_backup : {false, true}) { |
| ASSERT_WP2_OK(recorder.ResetRecord(reset_backup)); |
| counter.Reset(); |
| for (uint8_t i : data) { |
| symbol_counter->Process(/*sym=*/0, i, "test", &counter); |
| } |
| costs[reset_backup ? 1 : 0] = counter.GetCost(); |
| } |
| // If we use the recorded probabilities, we are much better at predicting the |
| // cost. |
| EXPECT_LT(costs[0], 0.5 * costs[1]); |
| } |
| |
| // Test that SymbolCounter gives accurate measurements when given a |
| // SymbolRecorder that has already recorded a bunch of data. |
| TEST_P(TestSymbolCounter, SymbolCounter_Midway) { |
| const SymbolCounterType type = GetParam(); |
| |
| WP2MathInit(); |
| ANSInit(); |
| static constexpr uint32_t kNumSymbol = 10000u; |
| UniformIntDistribution random(/*seed=*/0); |
| |
| // Only the UpdatingSymbolCounter is accurate for adaptive bits/symbols. |
| SymbolsInfoTest info; |
| |
| // Here we only test adaptive bits and symbols, for which the counter should |
| // know exactly how much space they'll take. |
| |
| // Adaptive bit. |
| info.SetInfo(0, /*min=*/0, /*max=*/1, /*num_clusters=*/4, |
| StorageMethod::kAdaptiveBit); |
| EXPECT_WP2_OK( |
| info.SetStartingProba(/*sym=*/0, /*cluster=*/2, /*p0=*/3, /*p1=*/7)); |
| EXPECT_WP2_OK( |
| info.SetStartingProba(/*sym=*/0, /*cluster=*/3, /*p0=*/100, /*p1=*/1)); |
| // Adaptive symbol. |
| info.SetInfo(1, /*min=*/0, /*max=*/9, /*num_clusters=*/3, |
| StorageMethod::kAdaptiveSym); |
| |
| const uint32_t num_symbol_types = info.Size(); |
| DataGenerator generator(info, &random); |
| |
| std::vector<DataGenerator::Symbol> symbols; |
| |
| SymbolRecorder recorder; |
| ASSERT_WP2_OK(recorder.Allocate(info, kNumSymbol)); |
| |
| // Generate data. |
| for (uint32_t s = 0; s < num_symbol_types; ++s) { |
| generator.AddData( |
| s, kNumSymbol, |
| kDistributions[s * kDistributions.size() / num_symbol_types], &recorder, |
| &symbols); |
| } |
| Shuffle(symbols.begin(), symbols.end(), 0); |
| |
| // Create the symbol writer and write the header. |
| std::unique_ptr<WP2::SymbolWriter> sw(new (WP2Allocable::nothrow) |
| WP2::SymbolWriter); |
| ASSERT_TRUE(sw != nullptr); |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| uint32_t max_nnz = 0; |
| generator.WriteHeader(info, recorder, &max_nnz, &dicts, &enc, sw.get()); |
| |
| ASSERT_WP2_OK(recorder.ResetRecord(/*reset_backup=*/true)); |
| |
| // Write and record half of the symbols. |
| const uint32_t half = symbols.size() / 2; |
| for (uint32_t i = 0; i < half; ++i) { |
| const DataGenerator::Symbol& sym = symbols[i]; |
| sw->Process(sym.s, sym.cluster, sym.value, "sym", &enc); |
| recorder.Process(sym.s, sym.cluster, sym.value, "sym", &enc); |
| } |
| |
| const float half_cost = enc.GetCost(); |
| |
| // Has this symbol been seen? per symbol type, per cluster. |
| std::vector<std::vector<bool>> seen(info.Size()); |
| for (uint32_t i = 0; i < info.Size(); ++i) { |
| seen[i].resize(info.NumClusters(i)); |
| } |
| |
| // Write the second half of symbols. |
| for (uint32_t i = half; i < symbols.size(); ++i) { |
| const DataGenerator::Symbol& sym = symbols[i]; |
| // The "fast" symbol counter doesn't update adaptive symbols, so it's only |
| // accurate for the first instance of each symbol. Therefore we only |
| // process one value for each symbol type/cluster. |
| if (type == kFast && seen[sym.s][sym.cluster]) continue; |
| |
| seen[sym.s][sym.cluster] = true; |
| sw->Process(sym.s, sym.cluster, sym.value, "sym", &enc); |
| } |
| const float real_cost = enc.GetCost() - half_cost; |
| |
| ANSEncCounter enc_counter; |
| std::unique_ptr<SymbolCounter> symbol_counter = |
| CreateSymbolCounter(type, recorder); |
| ASSERT_WP2_OK(symbol_counter->Allocate(/*syms=*/{0, 1})); |
| |
| for (uint32_t i = 0; i < info.Size(); ++i) { |
| std::fill(seen[i].begin(), seen[i].end(), false); |
| } |
| for (uint32_t i = half; i < symbols.size(); ++i) { |
| const DataGenerator::Symbol& sym = symbols[i]; |
| // The "fast" symbol counter doesn't update adaptive symbols, so it's only |
| // accurate for the first instance of each symbol. Therefore we only |
| // process one value for each symbol type/cluster. |
| if (type == kFast && seen[sym.s][sym.cluster]) continue; |
| seen[sym.s][sym.cluster] = true; |
| symbol_counter->Process(sym.s, sym.cluster, sym.value, "sym", &enc_counter); |
| } |
| |
| EXPECT_NEAR(enc_counter.GetCost(), real_cost, 1); |
| } |
| |
| struct PrefixCodeSymbol { |
| uint32_t value; |
| uint32_t range; |
| uint32_t prefix_size; // 0 to 10 for the tests |
| }; |
| |
| TEST(SymbolsTest, SimplePrefixCode) { |
| WP2MathInit(); |
| static constexpr uint32_t kNumSymbol = 10000u; |
| |
| // Make sure the prefix coding representation is bijective. |
| for (uint32_t prefix_size = 0; prefix_size <= 10; ++prefix_size) { |
| for (uint32_t prefix = 0; prefix < 20; ++prefix) { |
| const uint32_t extra_bits_num = |
| PrefixCode::NumExtraBits(prefix, prefix_size); |
| for (uint32_t extra_bits_value = 0; |
| extra_bits_value < (1u << extra_bits_num); ++extra_bits_value) { |
| const uint32_t value = |
| PrefixCode::Merge(prefix, prefix_size, extra_bits_value); |
| const PrefixCode prefix_code(value, prefix_size); |
| ASSERT_EQ(prefix_code.extra_bits_num, extra_bits_num); |
| ASSERT_EQ(prefix_code.extra_bits_value, extra_bits_value); |
| ASSERT_EQ(prefix_code.prefix, prefix); |
| } |
| } |
| } |
| |
| UniformIntDistribution random(/*seed=*/0); |
| std::vector<PrefixCodeSymbol> symbols; |
| symbols.reserve(kNumSymbol); |
| for (uint32_t i = 0; i < kNumSymbol; ++i) { |
| const uint32_t range = random.Get(1u, kANSMaxRange); |
| const uint32_t value = random.Get(0u, range - 1); |
| const uint32_t prefix_size = random.Get(0u, 10u); |
| const PrefixCodeSymbol sym = {value, range, prefix_size}; |
| const PrefixCode prefix_code(value, prefix_size); |
| ASSERT_EQ(prefix_code.extra_bits_num, |
| PrefixCode::NumExtraBits(prefix_code.prefix, prefix_size)); |
| ASSERT_EQ(value, PrefixCode::Merge(prefix_code.prefix, prefix_size, |
| prefix_code.extra_bits_value)); |
| symbols.push_back(sym); |
| } |
| |
| ANSEnc enc; |
| for (PrefixCodeSymbol& s : symbols) { |
| WritePrefixCode(s.value, /*min=*/0, /*max=*/s.range - 1, s.prefix_size, |
| &enc, "test"); |
| } |
| |
| ASSERT_WP2_OK(enc.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| for (PrefixCodeSymbol& s : symbols) { |
| ASSERT_EQ(s.value, (uint32_t)ReadPrefixCode(/*min=*/0, /*max=*/s.range - 1, |
| s.prefix_size, &dec, "test")); |
| } |
| } |
| |
| INSTANTIATE_TEST_SUITE_P(TestSymbolCounterInstanciation, TestSymbolCounter, |
| testing::Values(kUpdating)); |
| |
| //------------------------------------------------------------------------------ |
| |
| class SymbolWriterForTest : public SymbolWriter { |
| public: |
| using SymbolWriter::AddAdaptiveBit; |
| using SymbolWriter::AddAdaptiveSymbol; |
| using SymbolWriter::AddDict; |
| using SymbolWriter::AddPrefixCode; |
| using SymbolWriter::AddRange; |
| using SymbolWriter::AddTrivial; |
| }; |
| |
| class SymbolReaderForTest : public SymbolReader { |
| public: |
| using SymbolReader::AddAdaptiveBit; |
| using SymbolReader::AddAdaptiveSymbol; |
| using SymbolReader::AddDict; |
| using SymbolReader::AddPrefixCode; |
| using SymbolReader::AddRange; |
| using SymbolReader::AddTrivial; |
| }; |
| |
| // bool: use_max |
| class SymbolCostTest : public testing::TestWithParam<std::tuple<bool, bool>> { |
| protected: |
| float Decode(SymbolReader* const sr, const std::vector<int32_t>& values, |
| const std::vector<uint32_t>& max_values, float enc_cost) { |
| float dec_cost = enc_cost; |
| float* cost_ptr = nullptr; |
| #if defined(WP2_BITTRACE) |
| dec_cost = 0; |
| cost_ptr = &dec_cost; |
| #endif |
| for (uint32_t i = 0; i < kNumValues; ++i) { |
| SCOPED_TRACE(SPrintf("value %d", i)); |
| if (!max_values.empty()) { |
| int32_t v = 0; |
| EXPECT_WP2_OK(sr->ReadWithMax(kSymbol, kCluster, max_values[i], |
| "label", &v, cost_ptr)); |
| EXPECT_EQ(values[i], v); |
| } else { |
| EXPECT_EQ(values[i], sr->Read(kSymbol, "label", cost_ptr)); |
| } |
| } |
| return dec_cost; |
| } |
| |
| UniformIntDistribution random_; |
| |
| static constexpr uint32_t kCluster = 0; |
| static constexpr uint32_t kSymbol = 0; |
| static constexpr uint32_t kNumValues = 1000; |
| static constexpr int16_t kMinSymbol = -15; |
| static constexpr int16_t kMaxSymbol = 10; |
| // Allow 0.05% error per value. |
| static constexpr float kCostError = |
| (kNumValues * 0.05 / 100) < 0.01 ? 0.01 : (kNumValues * 0.05 / 100); |
| }; |
| |
| TEST_P(SymbolCostTest, TrivialCost) { |
| const bool use_max = std::get<0>(GetParam()); |
| const bool is_symmetric = std::get<1>(GetParam()); |
| |
| const int16_t min_symbol = use_max ? 0 |
| : is_symmetric ? -kMaxSymbol |
| : kMinSymbol; |
| |
| const int16_t kValue = 8; |
| |
| SymbolsInfo info; |
| info.SetInfo(kSymbol, min_symbol, kMaxSymbol, /*num_clusters=*/1, |
| StorageMethod::kAuto); |
| |
| SymbolWriterForTest sw; |
| ASSERT_WP2_OK(sw.Init(info, /*effort=*/5)); |
| ASSERT_WP2_OK(sw.Allocate()); |
| sw.AddTrivial(kSymbol, kCluster, is_symmetric, |
| is_symmetric ? kValue : kValue - min_symbol); |
| |
| std::vector<int32_t> values; |
| std::vector<uint32_t> max_values; |
| float enc_cost = 0.f; |
| ANSEnc enc; |
| for (uint32_t i = 0; i < kNumValues; ++i) { |
| const int32_t v = |
| is_symmetric && !use_max && random_.Get(0, 1) ? -kValue : kValue; |
| values.push_back(v); |
| if (use_max) { |
| const uint32_t max = random_.Get(kValue, kMaxSymbol); |
| sw.ProcessWithCost(kSymbol, kCluster, v, max, "label", &enc, &enc_cost); |
| max_values.push_back(max); |
| } else { |
| sw.ProcessWithCost(kSymbol, kCluster, v, "label", &enc, &enc_cost); |
| } |
| } |
| |
| ASSERT_WP2_OK(enc.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| |
| SymbolReaderForTest sr; |
| ASSERT_WP2_OK(sr.Init(info, &dec)); |
| ASSERT_WP2_OK(sr.Allocate()); |
| |
| sr.AddTrivial(kSymbol, kCluster, is_symmetric, |
| is_symmetric ? kValue : kValue - min_symbol); |
| |
| float dec_cost = Decode(&sr, values, max_values, enc_cost); |
| |
| EXPECT_EQ(enc_cost, dec_cost); |
| if (!is_symmetric) { |
| EXPECT_EQ(0., enc.GetCost()); |
| EXPECT_EQ(0., enc_cost); |
| } |
| } |
| |
| TEST_P(SymbolCostTest, RangeCost) { |
| const bool use_max = std::get<0>(GetParam()); |
| const bool is_symmetric = std::get<1>(GetParam()); |
| |
| SymbolsInfo info; |
| const int16_t min_symbol = use_max ? 0 |
| : is_symmetric ? -kMaxSymbol |
| : kMinSymbol; |
| const uint32_t range = kMaxSymbol - min_symbol + 1; |
| info.SetInfo(kSymbol, min_symbol, kMaxSymbol, /*num_clusters=*/1, |
| StorageMethod::kAuto); |
| |
| SymbolWriterForTest sw; |
| ASSERT_WP2_OK(sw.Init(info, /*effort=*/5)); |
| ASSERT_WP2_OK(sw.Allocate()); |
| sw.AddRange(kSymbol, kCluster, is_symmetric, /*mapping=*/nullptr, |
| /*size=*/0, range); |
| |
| std::vector<int32_t> values; |
| std::vector<uint32_t> max_values; |
| float enc_cost = 0.f; |
| ANSEnc enc; |
| for (uint32_t i = 0; i < kNumValues; ++i) { |
| int16_t v; |
| if (use_max) { |
| const uint32_t max = random_.Get<uint32_t>(0, kMaxSymbol); |
| v = random_.Get(0u, max); |
| max_values.push_back(max); |
| sw.ProcessWithCost(kSymbol, kCluster, v, max, "label", &enc, &enc_cost); |
| } else { |
| v = random_.Get(min_symbol, kMaxSymbol); |
| sw.ProcessWithCost(kSymbol, kCluster, v, "label", &enc, &enc_cost); |
| } |
| values.push_back(v); |
| } |
| |
| ASSERT_WP2_OK(enc.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| |
| SymbolReaderForTest sr; |
| ASSERT_WP2_OK(sr.Init(info, &dec)); |
| ASSERT_WP2_OK(sr.Allocate()); |
| |
| sr.AddRange(kSymbol, kCluster, is_symmetric, range); |
| |
| const float dec_cost = Decode(&sr, values, max_values, enc_cost); |
| |
| EXPECT_NEAR(enc.GetCost(), enc_cost, kCostError); |
| EXPECT_NEAR(enc.GetCost(), dec_cost, kCostError); |
| } |
| |
| TEST_P(SymbolCostTest, DictCost) { |
| const bool use_max = std::get<0>(GetParam()); |
| const bool is_symmetric = std::get<1>(GetParam()); |
| |
| SymbolsInfo info; |
| const int16_t min_symbol = use_max ? 0 |
| : is_symmetric ? -kMaxSymbol |
| : kMinSymbol; |
| info.SetInfo(kSymbol, min_symbol, kMaxSymbol, /*num_clusters=*/1, |
| StorageMethod::kAuto); |
| |
| std::vector<uint32_t> histogram = {10, 2, 50, 20, 5}; |
| std::vector<uint16_t> mapping = {0, 2, 3, 6, 7}; |
| const uint32_t symbol_size = histogram.size(); |
| ANSCodes infos; |
| ASSERT_TRUE(infos.resize(symbol_size)); |
| for (uint32_t i = 0; i < symbol_size; ++i) { |
| infos[i].freq = histogram[i]; |
| infos[i].symbol = mapping[i]; |
| } |
| |
| SymbolWriterForTest sw; |
| ASSERT_WP2_OK(sw.Init(info, /*effort=*/5)); |
| ASSERT_WP2_OK(sw.Allocate()); |
| |
| ANSDictionaries dicts; |
| ASSERT_WP2_OK(sw.AddDict(kSymbol, kCluster, is_symmetric, histogram.data(), |
| histogram.data(), mapping.data(), symbol_size, |
| &dicts)); |
| |
| std::vector<int32_t> values; |
| std::vector<uint32_t> max_values; |
| float enc_cost = 0.f; |
| ANSEnc enc; |
| for (uint32_t i = 0; i < kNumValues; ++i) { |
| const uint32_t max_i = use_max |
| ? random_.Get<uint32_t>(0, mapping.size() - 1) |
| : mapping.size() - 1; |
| const uint32_t max = mapping[max_i]; |
| int16_t v = mapping[random_.Get<uint32_t>(0, max_i)]; |
| v = is_symmetric ? random_.Get<int32_t>(0, 1) ? -v : v : min_symbol + v; |
| values.push_back(v); |
| if (use_max) { |
| max_values.push_back(max); |
| sw.ProcessWithCost(kSymbol, kCluster, v, max, "label", &enc, &enc_cost); |
| } else { |
| sw.ProcessWithCost(kSymbol, kCluster, v, "label", &enc, &enc_cost); |
| } |
| } |
| |
| ASSERT_WP2_OK(enc.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| |
| SymbolReaderForTest sr; |
| ASSERT_WP2_OK(sr.Init(info, &dec)); |
| ASSERT_WP2_OK(sr.Allocate()); |
| |
| ASSERT_WP2_OK(sr.AddDict(kSymbol, kCluster, is_symmetric, &infos)); |
| |
| const float dec_cost = Decode(&sr, values, max_values, enc_cost); |
| |
| EXPECT_NEAR(enc.GetCost(), enc_cost, kCostError); |
| EXPECT_NEAR(enc.GetCost(), dec_cost, kCostError); |
| } |
| |
| TEST_P(SymbolCostTest, PrefixCodeCost) { |
| const bool use_max = std::get<0>(GetParam()); |
| const bool is_symmetric = std::get<1>(GetParam()); |
| |
| SymbolsInfo info; |
| const int16_t min_symbol = use_max ? 0 |
| : is_symmetric ? -kMaxSymbol |
| : kMinSymbol; |
| info.SetInfo(kSymbol, min_symbol, kMaxSymbol, /*num_clusters=*/1, |
| StorageMethod::kAuto); |
| |
| std::vector<uint16_t> mapping = {1, 2, 4, 6, 9}; |
| |
| const uint32_t kPrefixCodePrefixSize = 0; |
| std::vector<uint32_t> prefix_code_histogram = {3, 7, 1, 2, 1}; |
| std::vector<uint16_t> prefix_code_mapping = {0, 1, 2, 3, 4}; |
| const uint32_t prefix_code_size = prefix_code_histogram.size(); |
| ANSCodes prefix_code_infos; |
| ASSERT_TRUE(prefix_code_infos.resize(prefix_code_size)); |
| for (uint32_t i = 0; i < prefix_code_size; ++i) { |
| prefix_code_infos[i].freq = prefix_code_histogram[i]; |
| prefix_code_infos[i].symbol = prefix_code_mapping[i]; |
| } |
| |
| SymbolWriterForTest sw; |
| ASSERT_WP2_OK(sw.Init(info, /*effort=*/5)); |
| ASSERT_WP2_OK(sw.Allocate()); |
| ANSDictionaries dicts; |
| ASSERT_WP2_OK(sw.AddPrefixCode( |
| kSymbol, kCluster, is_symmetric, prefix_code_histogram.data(), |
| prefix_code_histogram.data(), prefix_code_mapping.data(), |
| prefix_code_size, kPrefixCodePrefixSize, &dicts)); |
| |
| std::vector<int32_t> values; |
| std::vector<uint32_t> max_values; |
| float enc_cost = 0.f; |
| ANSEnc enc; |
| for (uint32_t i = 0; i < kNumValues; ++i) { |
| const uint32_t max_i = use_max |
| ? random_.Get<uint32_t>(0, mapping.size() - 1) |
| : mapping.size() - 1; |
| const uint32_t max = mapping[max_i]; |
| int16_t v = mapping[random_.Get<uint32_t>(0, max_i)]; |
| v = is_symmetric ? random_.Get<int32_t>(0, 1) ? -v : v : min_symbol + v; |
| values.push_back(v); |
| if (use_max) { |
| max_values.push_back(max); |
| sw.ProcessWithCost(kSymbol, kCluster, v, max, "label", &enc, &enc_cost); |
| } else { |
| sw.ProcessWithCost(kSymbol, kCluster, v, "label", &enc, &enc_cost); |
| } |
| } |
| |
| ASSERT_WP2_OK(enc.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| |
| SymbolReaderForTest sr; |
| ASSERT_WP2_OK(sr.Init(info, &dec)); |
| ASSERT_WP2_OK(sr.Allocate()); |
| |
| ASSERT_WP2_OK(sr.AddPrefixCode(kSymbol, kCluster, is_symmetric, |
| &prefix_code_infos, kPrefixCodePrefixSize)); |
| |
| const float dec_cost = Decode(&sr, values, max_values, enc_cost); |
| |
| EXPECT_NEAR(enc.GetCost(), enc_cost, kCostError); |
| EXPECT_NEAR(enc.GetCost(), dec_cost, kCostError); |
| } |
| |
| TEST_P(SymbolCostTest, AdaptiveBitCost) { |
| WP2MathInit(); |
| const bool use_max = std::get<0>(GetParam()); |
| const bool skip = std::get<1>(GetParam()); |
| if (skip) return; |
| |
| SymbolsInfo info; |
| info.SetInfo(kSymbol, /*min=*/0, /*max=*/1, /*num_clusters=*/1, |
| StorageMethod::kAdaptiveBit); |
| ANSDictionaries dicts; |
| |
| const uint32_t p0 = 3, p1 = 10; |
| |
| SymbolWriterForTest sw; |
| ASSERT_WP2_OK(sw.Init(info, /*effort=*/5)); |
| ASSERT_WP2_OK(sw.Allocate()); |
| ASSERT_WP2_OK(sw.AddAdaptiveBit(kSymbol, kCluster, p0, p1)); |
| |
| std::vector<int32_t> values; |
| std::vector<uint32_t> max_values; |
| float enc_cost = 0.f; |
| ANSEnc enc; |
| for (uint32_t i = 0; i < kNumValues; ++i) { |
| const uint32_t max = use_max ? random_.FlipACoin() : 1; |
| const uint32_t v = random_.Get<uint32_t>(0, max); |
| values.push_back(v); |
| if (use_max) { |
| max_values.push_back(max); |
| sw.ProcessWithCost(kSymbol, kCluster, v, max, "label", &enc, &enc_cost); |
| } else { |
| sw.ProcessWithCost(kSymbol, kCluster, v, "label", &enc, &enc_cost); |
| } |
| } |
| |
| ASSERT_WP2_OK(enc.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| |
| SymbolReaderForTest sr; |
| ASSERT_WP2_OK(sr.Init(info, &dec)); |
| ASSERT_WP2_OK(sr.Allocate()); |
| |
| ASSERT_WP2_OK(sr.AddAdaptiveBit(kSymbol, kCluster, p0, p1)); |
| |
| const float dec_cost = Decode(&sr, values, max_values, enc_cost); |
| |
| EXPECT_NEAR(enc.GetCost(), enc_cost, kCostError); |
| EXPECT_NEAR(enc.GetCost(), dec_cost, kCostError); |
| } |
| |
| TEST_P(SymbolCostTest, AdaptiveSymCost) { |
| ANSInit(); |
| const bool use_max = std::get<0>(GetParam()); |
| const bool skip = std::get<1>(GetParam()); |
| if (skip) return; |
| |
| for (int method = 0; method < (int)ANSAdaptiveSymbol::Method::kNum; |
| ++method) { |
| SCOPED_TRACE(SPrintf("method %d", method)); |
| const uint32_t speed = (method == (int)ANSAdaptiveSymbol::Method::kAOM) |
| ? kANSAProbaInvalidSpeed |
| : 42; |
| |
| SymbolsInfo info; |
| info.SetInfo(kSymbol, /*min=*/0, /*max=*/kMaxSymbol, /*num_clusters=*/1, |
| StorageMethod::kAdaptiveSym); |
| |
| SymbolWriterForTest sw; |
| ASSERT_WP2_OK(sw.Init(info, /*effort=*/5)); |
| ASSERT_WP2_OK(sw.Allocate()); |
| ASSERT_WP2_OK(sw.AddAdaptiveSymbol( |
| kSymbol, kCluster, (ANSAdaptiveSymbol::Method)method, speed)); |
| |
| std::vector<int32_t> values; |
| std::vector<uint32_t> max_values; |
| float enc_cost = 0.f; |
| ANSEnc enc; |
| for (uint32_t i = 0; i < kNumValues; ++i) { |
| const uint32_t max = |
| use_max ? random_.Get<uint32_t>(0, kMaxSymbol) : kMaxSymbol; |
| const uint32_t v = random_.Get<uint32_t>(0, max); |
| values.push_back(v); |
| if (use_max) { |
| max_values.push_back(max); |
| sw.ProcessWithCost(kSymbol, kCluster, v, max, "label", &enc, &enc_cost); |
| } else { |
| sw.ProcessWithCost(kSymbol, kCluster, v, "label", &enc, &enc_cost); |
| } |
| } |
| |
| ASSERT_WP2_OK(enc.AssembleToBitstream()); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| |
| SymbolReaderForTest sr; |
| ASSERT_WP2_OK(sr.Init(info, &dec)); |
| ASSERT_WP2_OK(sr.Allocate()); |
| |
| ASSERT_WP2_OK(sr.AddAdaptiveSymbol( |
| kSymbol, kCluster, (ANSAdaptiveSymbol::Method)method, speed)); |
| |
| const float dec_cost = Decode(&sr, values, max_values, enc_cost); |
| |
| EXPECT_NEAR(enc.GetCost(), enc_cost, kCostError); |
| EXPECT_NEAR(enc.GetCost(), dec_cost, kCostError); |
| } |
| } |
| |
| INSTANTIATE_TEST_SUITE_P(SymbolReaderTestInstantiation, SymbolCostTest, |
| testing::Combine(testing::Values(false, true), |
| testing::Values(false, true))); |
| |
| //------------------------------------------------------------------------------ |
| |
| // Test an edge case with the capped SymbolWriter/SymbolReader. |
| #if !(defined(WP2_BITTRACE) || defined(WP2_TRACE) || defined(WP2_ENC_DEC_MATCH)) |
| // Only test when there is no tracing as we will artificially set the state of |
| // the ANS decoder with a U value and read it as a symbol. |
| TEST(SymbolsTest, MaxValuePrecision) { |
| // Define the faulty statistics. |
| constexpr uint32_t kRange = 6; |
| SymbolsInfoTest info; |
| info.SetInfo(0, /*min=*/0, /*max=*/kRange - 1, /*num_clusters=*/1, |
| StorageMethod::kAuto); |
| constexpr uint32_t counts[kRange] = {4096, 2048, 8192, 512, 1024, 512}; |
| |
| // Create the symbol writer. |
| std::unique_ptr<WP2::SymbolWriter> sw(new (WP2Allocable::nothrow) |
| WP2::SymbolWriter); |
| ASSERT_TRUE(sw != nullptr); |
| ANSEnc enc; |
| ANSDictionaries dicts; |
| ASSERT_WP2_OK(sw->Init(info, /*effort=*/5)); |
| ASSERT_WP2_OK(sw->Allocate()); |
| constexpr uint32_t max_nnz = 10000; |
| ASSERT_WP2_OK(sw->WriteHeader(/*sym=*/0, /*cluster=*/0, max_nnz, counts, |
| "counts", &enc, &dicts)); |
| |
| // This value is actually the value of the ANS state that can crash the capped |
| // decoding. It is right at the end of the last interval, which is not |
| // properly defined: it goes up to 16382 if we use the usual roundings, but |
| // we want it to go to the end of the tab size. |
| enc.PutUValue(ANS_TAB_SIZE - 1, ANS_LOG_TAB_SIZE, "hack"); |
| // No need to write data as the stats are stored anyway. Just wrap up. |
| EXPECT_WP2_OK(enc.AssembleToBitstream(/*clear_tokens=*/true)); |
| Vector_u8 bits; |
| EXPECT_WP2_OK(enc.WriteBitstreamTo(bits)); |
| // Decode the symbol reader statistics. |
| ExternalDataSource data_source(bits.data(), bits.size()); |
| ANSDec dec(&data_source); |
| SymbolReader sr; |
| ASSERT_WP2_OK(sr.Init(info, &dec)); |
| ASSERT_WP2_OK(sr.Allocate()); |
| ASSERT_WP2_OK(sr.ReadHeader(/*sym=*/0, /*cluster=*/0, max_nnz, "counts")); |
| |
| // Check the decoder does not crash. |
| int32_t value; |
| ASSERT_WP2_OK(sr.ReadWithMax(/*sym=*/0, /*cluster=*/0, /*max_value=*/1, "sym", |
| &value, nullptr)); |
| } |
| #endif |
| |
| } // namespace |
| } // namespace WP2 |