blob: 282e0bbb5e22422abee091ce4764b315a6a9e67d [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Measures the pessimization level of the wp2 codec format by computing the
// ratio of random and mutated bitstreams that decode successfully to those
// that are invalid.
#include <algorithm>
#include <cstdio>
#include <limits>
#include <string>
#include <tuple>
#include <vector>
#include "imageio/image_dec.h"
#include "include/helpers.h"
#include "src/common/constants.h"
#include "src/common/header_enc_dec.h"
#include "src/dec/tile_dec.h"
#include "src/dec/wp2_dec_i.h"
#include "src/enc/analysis.h"
#include "src/enc/wp2_enc_i.h"
#include "src/utils/data_source.h"
#include "src/utils/random.h"
#include "src/utils/utils.h"
#include "src/wp2/base.h"
#include "src/wp2/decode.h"
#include "src/wp2/encode.h"
// Define WP2_ERROR_TRACE (WP2_TRACE also works) to gather stats about where the
// random bitstreams fail to decode. Do not use multithreading, not even at the
// tile or test level.
#if defined(WP2_TRACE) || defined(WP2_ERROR_TRACE)
#include <unordered_map>
#include "examples/example_utils.h"
// Collects the file:line locations reported through WP2ErrorTracer and counts
// how often each distinct sequence of locations occurs between two Flush()
// calls. The counts are written to /tmp/wp2_error_trace.txt on destruction.
// The instance registers itself in the global 'wp2_error_tracer' pointer for
// its whole lifetime. Copies are therefore forbidden: a copy would unregister
// the tracer and rewrite the trace file a second time when destroyed.
class ErrorTracer : public WP2ErrorTracer {
 public:
  ErrorTracer() { wp2_error_tracer = this; }
  ErrorTracer(const ErrorTracer&) = delete;
  ErrorTracer& operator=(const ErrorTracer&) = delete;
  ~ErrorTracer() {
    wp2_error_tracer = nullptr;
    WriteToFile("/tmp/wp2_error_trace.txt");
  }

  // Appends "file:line" to the current callstack.
  void Log(const char* const file, int line, WP2Status, const char*) override {
    // Lighten the callstack.
    const std::string call =
        WP2::GetFileName(file) + ":" + std::to_string(line);
    if (!callstack_.empty()) callstack_ += ", ";
    callstack_ += call;  // Reversed call order but more convenient text output.
  }

  // Returns at most the first 70 characters of the current callstack.
  std::string GetShortCallstack() const { return callstack_.substr(0, 70); }

  // Counts the current callstack (the empty string if nothing was logged since
  // the last call) and starts a new one.
  void Flush() {
    ++callstack_count_[callstack_];
    callstack_.clear();
  }

  // Writes one "count, callstack" line per distinct callstack to 'file_path'.
  // Does nothing if the file cannot be opened.
  void WriteToFile(const char* const file_path) {
    FILE* const file = fopen(file_path, "w");
    if (file != nullptr) {
      for (const auto& it : callstack_count_) {
        fprintf(file, "%u, %s\n", it.second, it.first.c_str());
      }
      fclose(file);
    }
  }

 private:
  std::string callstack_;
  std::unordered_map<std::string, uint32_t> callstack_count_;
};
#else
// No-op stand-in used when tracing is disabled.
class ErrorTracer {
 public:
  std::string GetShortCallstack() const { return ""; }
  void Flush() {}
};
#endif  // defined(WP2_TRACE) || defined(WP2_ERROR_TRACE)
namespace WP2 {
namespace {
//------------------------------------------------------------------------------
// Inspired from GlobalAnalysis() but analysis-driven values are replaced by
// random ones.
// Resets 'gparams' and fills it from 'config'. The following are drawn from
// 'random':
//  - the alpha presence (only if 'image_has_alpha');
//  - the number of segments, plus each segment's risk and quantization factor;
//  - the lossy alpha choice;
//  - the YUV filter magnitude.
// Lossy-specific fields are only set for the GP_LOSSY and GP_BOTH types.
// The order of the random draws below determines which parameters a given seed
// produces; keep it stable so that results stay reproducible.
WP2Status GenerateRandomGlobalParams(const EncoderConfig& config,
                                     bool image_has_alpha,
                                     UniformIntDistribution* const random,
                                     GlobalParams* const gparams) {
  // TODO(yguyon): Also test fully random GlobalParams.
  gparams->Reset();
  gparams->type_ = DecideGlobalParamsType(config);
  // Short-circuit: no coin is flipped when the image has no alpha.
  gparams->has_alpha_ = image_has_alpha && random->FlipACoin();
  if (gparams->type_ == GlobalParams::GP_LOSSY ||
      gparams->type_ == GlobalParams::GP_BOTH) {
    // Fields copied as is from the configuration.
    gparams->partition_set_ = config.partition_set;
    gparams->partition_snapping_ = config.partition_snapping;
    // Explicit segment ids if requested, or in AUTO mode above quality 63.
    gparams->explicit_segment_ids_ =
        (config.segment_id_mode == WP2::EncoderConfig::SEGMENT_ID_AUTO)
            ? (config.quality > 63.f)
            : (config.segment_id_mode ==
               WP2::EncoderConfig::SEGMENT_ID_EXPLICIT);
    gparams->use_rnd_mtx_ = config.use_random_matrix;
    WP2_CHECK_STATUS(gparams->transf_.Init(config.csp_type));
    // Random number of segments in [1, config.segments].
    const uint32_t num_segments = random->Get(1, config.segments);
    WP2_CHECK_ALLOC_OK(gparams->segments_.resize(num_segments));
    for (uint32_t i = 0; i < num_segments; ++i) {
      // Risk in [0, 1] with a granularity of 1/127.
      gparams->segments_[i].risk_ = random->Get(0, 127) / 127.f;
      gparams->segments_[i].risk_class_ = i;
      // Quantization factor in [1 / 1.28, 100].
      gparams->segments_[i].SetQuantizationFactor(
          gparams->transf_, kDefaultQuantOffset, kDefaultQuantOffset,
          /*qfactor=*/random->Get(1, 128) / 1.28f);
      gparams->segments_[i].use_grain_ = false;  // Nah.
    }
    // The coin is only flipped if all the previous conditions hold.
    gparams->maybe_use_lossy_alpha_ =
        gparams->has_alpha_ && (config.effort > 0) &&
        (config.alpha_quality <= kMaxLossyQuality) && random->FlipACoin();
    gparams->enable_alpha_filter_ = gparams->maybe_use_lossy_alpha_;
    // Maps alpha_quality 0..kMaxLossyQuality linearly to 99..0.
    // NOTE(review): negative when config.alpha_quality > kMaxLossyQuality
    // (e.g. quality 100 in the tests), so confirm that
    // SetAlphaQuantizationFactor() tolerates it.
    const float quant_factor =
        99.f * (kMaxLossyQuality - config.alpha_quality) / kMaxLossyQuality;
    gparams->segments_.front().SetAlphaQuantizationFactor(quant_factor);
    WP2_CHECK_STATUS(gparams->InitFixedPredictors());
    gparams->yuv_filter_magnitude_ = random->Get(0u, kMaxYuvFilterMagnitude);
  }
  return WP2_STATUS_OK;
}
//------------------------------------------------------------------------------
// Generates a random wp2 bitstream header to 'output'.
// It contains the bitstream features and a semi-random GlobalParams.
// 'config' is copied and its tile shape is forced to TILE_SHAPE_SQUARE_512,
// presumably so that the caller's images (at most 512x512) fit in one tile.
// That adjusted and validated configuration is used consistently for the
// header, for generating the GlobalParams and for encoding them.
// Returns WP2_STATUS_INVALID_CONFIGURATION if the adjusted config is invalid
// or requests a preview.
WP2Status RandomEncodeBeforeTile(uint32_t width, uint32_t height,
                                 const EncoderConfig& config,
                                 UniformIntDistribution* const random,
                                 Writer* const output) {
  EncoderConfig final_config = config;
  final_config.tile_shape = TILE_SHAPE_SQUARE_512;
  WP2_CHECK_OK(final_config.IsValid(), WP2_STATUS_INVALID_CONFIGURATION);
  WP2_CHECK_OK(!final_config.create_preview, WP2_STATUS_INVALID_CONFIGURATION);

  // Draw order (bit depth, alpha, then GlobalParams) defines the seed mapping.
  const uint32_t rgb_bit_depth = random->FlipACoin() ? 10 : 8;
  const bool has_alpha = random->FlipACoin();

  // Append proxy header data, only config and dimensions are relevant.
  WP2_CHECK_STATUS(EncodeHeader(final_config, width, height, rgb_bit_depth,
                                has_alpha, /*is_anim=*/false,
                                /*loop_forever=*/true, kDefaultBackgroundColor,
                                /*preview_color=*/RGB12b(), /*has_icc=*/false,
                                /*has_trailing_data=*/false, output));

  // Append randomly generated global parameters, derived from the same
  // configuration as the header.
  GlobalParams gparams;
  WP2_CHECK_STATUS(
      GenerateRandomGlobalParams(final_config, has_alpha, random, &gparams));
  WP2_CHECK_STATUS(EncodeGLBL(final_config, gparams, has_alpha, output));
  return WP2_STATUS_OK;
}
// Generates a random wp2 tile bitstream to 'output'.
// It contains a randomly-sized tile chunk filled with random bytes: the chunk
// size written as a variable-length integer, followed by that many bytes.
WP2Status RandomEncodeTile(uint32_t width, uint32_t height,
                           UniformIntDistribution* const random,
                           Writer* const output) {
  // Append a randomly sized random bitstream representing a single tile.
  // Testing several tiles at once would not bring useful information as they
  // are isolated. Clamp to ~4 bytes per pixel, but never below the minimum
  // size. For tiny images (1x1, 1x2, 2x1) 4 bytes per pixel is less than the
  // minimum and the range [kMinNumBytes, width * height * 4] would be empty.
  constexpr uint32_t kMinNumBytes = 10u;
  const uint32_t max_num_bytes = std::max(kMinNumBytes, width * height * 4u);
  const uint32_t num_bytes = random->Get(kMinNumBytes, max_num_bytes);
  {
    uint8_t buf[kMaxVarIntLength];
    const uint32_t buf_size = WriteVarInt(num_bytes, 1, kMaxChunkSize, buf);
    WP2_CHECK_ALLOC_OK(output->Append(buf, buf_size));
  }
  Vector_u8 bytes;
  WP2_CHECK_ALLOC_OK(bytes.resize(num_bytes));
  for (uint8_t& byte : bytes) byte = random->Get<uint8_t>(0, 255);
  WP2_CHECK_ALLOC_OK(output->Append(bytes.data(), bytes.size()));
  return WP2_STATUS_OK;
}
//------------------------------------------------------------------------------
// Generates as many random bytes as requested.
// A DataSource whose content is created on demand: whenever more bytes are
// requested than available, the missing ones are drawn from 'random_' and
// appended to 'bytes_'. The total size is mirrored in 'tile_->chunk_size'
// because LossyDecode() and LosslessDecode() check it.
class RandomDataSource : public DataSource {
 public:
  explicit RandomDataSource(UniformIntDistribution* const random,
                            Tile* const tile)
      : random_(random), tile_(tile) {}
  // Source of the generated bytes (not deleted by this class).
  UniformIntDistribution* const random_;
  // Tile whose 'chunk_size' tracks the number of generated bytes.
  Tile* const tile_;
  // Every byte generated so far, discarded ones included.
  Vector_u8 bytes_;

 private:
  // See DataSource.
  // Grows 'bytes_' so that 'num_requested_bytes' are available after the
  // already read ones. Only the newly added tail is filled with random values.
  bool Fetch(size_t num_requested_bytes) override {
    tile_->chunk_size =
        num_discarded_bytes_ + num_read_bytes_ + num_requested_bytes;
    if (!bytes_.resize(tile_->chunk_size)) return false;
    // Bytes before this index were already generated by a previous Fetch().
    for (size_t i = num_discarded_bytes_ + num_available_bytes_;
         i < tile_->chunk_size; ++i) {
      bytes_[i] = random_->Get<uint8_t>(0, 255);
    }
    // Re-pointed after resize(), which may have moved the buffer.
    available_bytes_ = bytes_.data() + num_discarded_bytes_;
    num_available_bytes_ = num_read_bytes_ + num_requested_bytes;
    return true;
  }
  // No-op: discarded bytes stay in 'bytes_' so the whole generated tile can be
  // appended to the output afterwards.
  void OnDiscard(size_t) override {}
};
//------------------------------------------------------------------------------
// Summary of one generated bitstream, printed as a row of the command line
// report.
struct BitstreamStats {
  size_t length = 0;           // Bitstream size in bytes, header included.
  BitstreamFeatures features;  // Features read from the header.
  bool is_valid = false;       // Whether it decodes without error.
  std::string callstack;       // Only filled if WP2_ERROR_TRACE.

  // Prints the column names, in the same order as PrintRow().
  static void PrintTitleRow() { printf("Length Width Height Bpp Valid\n"); }

  // Prints this bitstream's stats, then the error callstack if there is one.
  void PrintRow() const {
    const auto num_pixels = features.width * features.height;
    const float bytes_per_pixel = length / (float)num_pixels;
    const char* const validity = is_valid ? "Yes" : "No";
    printf("%7u, %5u, %6u, %6.2f, %s", (uint32_t)length, features.width,
           features.height, bytes_per_pixel, validity);
    if (!callstack.empty()) printf(", %s", callstack.c_str());
    printf("\n");
  }
};
//------------------------------------------------------------------------------
// Fully decodes the wp2 'bitstream' and sets 'is_valid' to whether it decoded
// without error. WP2_STATUS_BITSTREAM_ERROR only marks the bitstream as
// invalid; any other decoding failure is returned. If 'stats' is not null, it
// receives the bitstream length, its header features and its validity.
WP2Status GetBistreamStats(const uint8_t bitstream[], size_t bitstream_length,
                           bool* const is_valid,
                           BitstreamStats* const stats = nullptr) {
  ArgbBuffer decoded;
  const WP2Status decode_status =
      Decode(bitstream, bitstream_length, &decoded);
  const bool is_bitstream_error =
      (decode_status == WP2_STATUS_BITSTREAM_ERROR);
  if (!is_bitstream_error) WP2_CHECK_STATUS(decode_status);
  *is_valid = (decode_status == WP2_STATUS_OK);
  if (stats == nullptr) return WP2_STATUS_OK;

  stats->length = bitstream_length;
  WP2_CHECK_STATUS(stats->features.Read(bitstream, bitstream_length));
  stats->is_valid = *is_valid;
  return WP2_STATUS_OK;
}
// Decodes the BitstreamFeatures and GlobalParams from the given 'output', then
// generates as many random bytes as necessary until LossyDecode() or
// LosslessDecode() returns. Appends these bytes, preceded by their count as a
// variable-length integer, to 'output' even if there was a decoding error.
// GP_BOTH and GP_AV1 tiles are not decoded. They are reported as valid and
// nothing is appended to 'output'.
// 'is_valid' tells whether the tile decoded without error. 'stats', if not
// null, is filled on every successful path, including the skipped tile types,
// so that it never keeps default-constructed features.
// Decoding failures other than WP2_STATUS_BITSTREAM_ERROR are returned.
WP2Status RandomEncodeDecodeTile(uint32_t width, uint32_t height,
                                 UniformIntDistribution* const random,
                                 MemoryWriter* const output,
                                 bool* const is_valid,
                                 BitstreamStats* const stats = nullptr) {
  // Setup input and config.
  ExternalDataSource header_data(output->mem_, output->size_);
  DecoderConfig config = DecoderConfig::kDefault;
  config.incremental_mode = DecoderConfig::IncrementalMode::FULL_TILE;
  WP2DecDspInit();

  // Decode the header.
  BitstreamFeatures features;
  WP2_CHECK_STATUS(features.Read(output->mem_, output->size_));
  header_data.MarkNumBytesAsRead(features.header_size);
  GlobalParams gparams;
  WP2_CHECK_STATUS(DecodeGLBL(&header_data, config, features, &gparams));

  // Stores the outcome in 'is_valid' and, if requested, in 'stats'.
  const auto report = [&](bool valid) {
    *is_valid = valid;
    if (stats != nullptr) {
      stats->length = output->size_;
      stats->features = features;
      stats->is_valid = valid;
    }
  };

  // Setup the decoding environment.
  TilesLayout tiles_layout;
  tiles_layout.num_tiles_x = tiles_layout.num_tiles_y = 1;
  tiles_layout.tile_width = width;
  tiles_layout.tile_height = height;
  tiles_layout.num_assignable_tiles = 1;
  WP2_CHECK_ALLOC_OK(tiles_layout.tiles.resize(1));
  tiles_layout.gparams = &gparams;
  Tile& tile = tiles_layout.tiles.front();
  tile.rect = {0, 0, width, height};
  tile.chunk_size_is_known = false;
  tile.chunk_size = 0;  // Updated by the RandomDataSource because it is checked
                        // in LossyDecode() and LosslessDecode().

  // Generate as many bytes as needed.
  RandomDataSource tile_data(random, &tile);
  ANSDec dec(&tile_data);
  WP2Status status = WP2_STATUS_OK;
  if (gparams.type_ == GlobalParams::GP_BOTH ||
      gparams.type_ == GlobalParams::GP_AV1) {
    report(/*valid=*/true);
    return WP2_STATUS_OK;  // Ignore neural decoding or AV1.
  } else if (gparams.type_ == GlobalParams::GP_LOSSY) {
    WP2_CHECK_STATUS(
        tile.yuv_output.Resize(width, height, kPredWidth, gparams.has_alpha_));
    status = LossyDecode(features, config, &tiles_layout, &dec, &tile);
  } else {
    assert(gparams.type_ == GlobalParams::GP_LOSSLESS);
    WP2_CHECK_STATUS(tile.rgb_output.Resize(width, height));
    status = LosslessDecode(features, config, gparams, &dec, &tile);
  }

  // Append the generated tile to 'output' for accurate stats.
  const size_t tile_size = tile_data.bytes_.size();
  WP2_CHECK_OK(tile_size > 0 && tile_size <= kMaxChunkSize,
               WP2_STATUS_BITSTREAM_ERROR);
  uint8_t buf[kMaxVarIntLength];
  const uint32_t buf_size = WriteVarInt(tile_size, 1, kMaxChunkSize, buf);
  WP2_CHECK_ALLOC_OK(output->Append(buf, buf_size));
  WP2_CHECK_ALLOC_OK(
      output->Append(tile_data.bytes_.data(), tile_data.bytes_.size()));

  // Make sure the error is not a real issue.
  if (status != WP2_STATUS_BITSTREAM_ERROR) WP2_CHECK_STATUS(status);
  report(status == WP2_STATUS_OK);
  return WP2_STATUS_OK;
}
//------------------------------------------------------------------------------
// Generates 'num_tries' random bitstreams and counts in 'num_valid' (if not
// null) those that decode without error. Each try uses its own generator,
// seeded from a sequence derived from 'seed_generator_seed'.
// If 'fixed_size_tile', the tile chunk gets a random size decided upfront and
// the whole bitstream is then decoded. Otherwise the tile bytes are generated
// on demand while decoding. One BitstreamStats per try is appended to
// 'all_stats' if it is not null.
WP2Status GetPessimizationProportion(
    size_t seed_generator_seed, const EncoderConfig& config,
    bool fixed_size_tile, uint32_t num_tries, uint32_t* const num_valid,
    std::vector<BitstreamStats>* const all_stats = nullptr) {
  ErrorTracer error_tracer;
  UniformIntDistribution seed_generator(seed_generator_seed);
  if (num_valid != nullptr) *num_valid = 0;

  for (uint32_t attempt = 0; attempt < num_tries; ++attempt) {
    const size_t attempt_seed = seed_generator.Get(0u, 1234567890u);
    UniformIntDistribution random(attempt_seed);

    BitstreamStats* stats = nullptr;
    if (all_stats != nullptr) {
      all_stats->emplace_back();
      stats = &all_stats->back();
    }

    // Dimensions are drawn first (width, then height), then the header.
    MemoryWriter writer;
    const uint32_t width = random.Get(1, 512);
    const uint32_t height = random.Get(1, 512);
    WP2_CHECK_STATUS(
        RandomEncodeBeforeTile(width, height, config, &random, &writer));

    bool is_valid = false;
    if (!fixed_size_tile) {
      WP2_CHECK_STATUS(RandomEncodeDecodeTile(width, height, &random, &writer,
                                              &is_valid, stats));
    } else {
      WP2_CHECK_STATUS(RandomEncodeTile(width, height, &random, &writer));
      WP2_CHECK_STATUS(
          GetBistreamStats(writer.mem_, writer.size_, &is_valid, stats));
    }

    if (stats != nullptr) stats->callstack = error_tracer.GetShortCallstack();
    error_tracer.Flush();
    if (is_valid && num_valid != nullptr) ++*num_valid;
  }
  return WP2_STATUS_OK;
}
//------------------------------------------------------------------------------
// Sanity check: a bitstream produced by the real encoder must decode fine.
TEST(PessimizationTest, NotRandom) {
  const std::string image_path =
      testutil::GetTestDataPath("source1_64x48.png");
  ArgbBuffer original;
  ASSERT_WP2_OK(ReadImage(image_path.c_str(), &original));

  MemoryWriter encoded;
  ASSERT_WP2_OK(Encode(original, &encoded));

  bool is_valid = false;
  ASSERT_WP2_OK(GetBistreamStats(encoded.mem_, encoded.size_, &is_valid));
  EXPECT_TRUE(is_valid);
}
// Parameterized fixture for trying different configurations. Parameters, in
// order: seed, quality (also used as alpha quality), fixed tile size, number
// of tries, and whether to print per-bitstream stats.
struct PessimizationTest
    : public testing::TestWithParam<
          std::tuple<size_t, float, bool, uint32_t, bool>> {};
// Generates random bitstreams and checks that enough of them decode fine.
TEST_P(PessimizationTest, RandomBitstream) {
  size_t seed;
  float quality;
  bool fixed_tile_size;
  uint32_t num_tries;
  bool print_stats;
  std::tie(seed, quality, fixed_tile_size, num_tries, print_stats) =
      GetParam();

  EncoderConfig config = EncoderConfig::kDefault;
  config.quality = quality;
  config.alpha_quality = quality;

  uint32_t num_valid;
  std::vector<BitstreamStats> all_stats;
  std::vector<BitstreamStats>* const stats_output =
      print_stats ? &all_stats : nullptr;
  ASSERT_WP2_OK(GetPessimizationProportion(seed, config, fixed_tile_size,
                                           num_tries, &num_valid,
                                           stats_output));

  // No expectation for fixed-size random tiles; otherwise the threshold
  // depends on whether the quality is lossy.
  float min_percentage = 0.f;
  if (!fixed_tile_size) {
    min_percentage = (config.quality <= kMaxLossyQuality) ? 50.f : 40.f;
  }
  const float percentage_of_valid_bitstreams = 100.f * num_valid / num_tries;
  EXPECT_GE(percentage_of_valid_bitstreams, min_percentage);

  if (!print_stats) return;
  BitstreamStats::PrintTitleRow();
  for (const BitstreamStats& row : all_stats) row.PrintRow();
  printf("Valid bitstreams: %u/%u (%.3f %%)\n", num_valid, num_tries,
         percentage_of_valid_bitstreams);
}
// Runs RandomBitstream for every combination of the values below:
// 2 qualities x 2 tile modes, with 10 random bitstreams each.
// Set the last value to true to print one row of stats per bitstream.
INSTANTIATE_TEST_SUITE_P(
    PessimizationTestInstantiation, PessimizationTest,
    testing::Combine(testing::Values(42),             // seed
                     testing::Values(75.f, 100.f),    // quality
                     testing::Values(true, false),    // fixed tile size
                     testing::Values(10),             // num tries
                     testing::Values(false)           // print stats
                     ));
//------------------------------------------------------------------------------
// Parses the header, preview, ICC, ANMF and GLBL parts of the wp2 'bitstream'.
// Outputs in 'tile_bytes_position' the offset of the first tile byte, which is
// the first byte of the variable-length integer holding the tile chunk size.
// Also outputs the tile dimensions read from the header.
WP2Status GetTileBytesPosition(const uint8_t bitstream[],
                               size_t bitstream_length,
                               size_t* const tile_bytes_position,
                               uint32_t* const tile_width,
                               uint32_t* const tile_height) {
  const DecoderConfig& decoder_config = DecoderConfig::kDefault;
  ExternalDataSource input(bitstream, bitstream_length);

  // Everything up to and including the GLBL chunk comes before the tiles.
  BitstreamFeatures features;
  WP2_CHECK_STATUS(DecodeHeader(&input, &features));
  WP2_CHECK_STATUS(SkipPreview(&input, features));
  WP2_CHECK_STATUS(SkipICC(&input, features));

  AnimationFrame first_frame;
  uint32_t first_frame_index = 0;
  WP2_CHECK_STATUS(DecodeANMF(decoder_config, &input, features,
                              first_frame_index, &first_frame));

  GlobalParams gparams;
  WP2_CHECK_STATUS(DecodeGLBL(&input, decoder_config, features, &gparams));

  const size_t consumed = input.GetNumDiscardedBytes() + input.GetNumReadBytes();
  *tile_bytes_position = consumed;
  *tile_width = features.tile_width;
  *tile_height = features.tile_height;
  return WP2_STATUS_OK;
}
// Overwrites one randomly located byte of 'bytes' with a random value, which
// may equal the previous one. 'num_bytes' must be at least 1.
void RandomByteChange(uint8_t bytes[], size_t num_bytes,
                      UniformIntDistribution* const random) {
  assert(num_bytes > 0);
  const uint32_t last_index = static_cast<uint32_t>(num_bytes) - 1u;
  const uint32_t position = random->Get(0u, last_index);
  const uint8_t new_value = random->Get<uint8_t>(0, 255);
  bytes[position] = new_value;
}
// Randomly changes one byte of the tile chunk payload in 'tile_bytes'. The
// variable-length chunk size prefix is left untouched, so the chunk size stays
// consistent. Returns WP2_STATUS_NOT_ENOUGH_DATA if the prefix plus the size it
// encodes does not span exactly 'tile_num_bytes'.
WP2Status MutateBitstream(uint32_t tile_width, uint32_t tile_height,
                          uint8_t tile_bytes[], size_t tile_num_bytes,
                          UniformIntDistribution* const random) {
  // Read the chunk size prefix, bounded by the largest possible tile chunk.
  ExternalDataSource chunk(tile_bytes, tile_num_bytes);
  const uint32_t max_chunk_size =
      kMaxNumBytesPerPixel * tile_width * tile_height;
  uint32_t payload_size;
  WP2_CHECK_STATUS(TryReadVarInt(&chunk, 1, max_chunk_size, &payload_size));
  const size_t prefix_size = chunk.GetNumReadBytes();
  WP2_CHECK_OK(prefix_size + payload_size == tile_num_bytes,
               WP2_STATUS_NOT_ENOUGH_DATA);

  uint8_t* const payload = tile_bytes + prefix_size;
  RandomByteChange(payload, tile_num_bytes - prefix_size, random);
  // TODO(yguyon): Add more operations
  return WP2_STATUS_OK;
}
//------------------------------------------------------------------------------
// Mutates a valid bitstream rather than generating completely random ones.
// Parameters, in order: test image file name, seed, quality (also used as
// alpha quality), number of tries, and whether to print stats.
struct ImageMutationTest
    : public testing::TestWithParam<
          std::tuple<const char*, size_t, float, uint32_t, bool>> {};
// Encodes a real image, then decodes 'num_tries' copies of the bitstream, each
// with one random byte of tile data mutated, and counts the valid ones.
TEST_P(ImageMutationTest, FromValidImage) {
  const char* file_name;
  size_t seed;
  float quality;
  uint32_t num_tries;
  bool print_stats;
  std::tie(file_name, seed, quality, num_tries, print_stats) = GetParam();

  EncoderConfig config = EncoderConfig::kDefault;
  config.quality = quality;
  config.alpha_quality = quality;
  if (config.tile_shape == TILE_SHAPE_AUTO) {
    config.tile_shape = TILE_SHAPE_SQUARE_512;
  }
  // Copied to a local: std::min() binds its arguments by reference.
  const uint32_t max_tile_size = kMaxTileSize;

  // Read the image and encode its top-left tile-sized crop; only the
  // bitstream is kept.
  MemoryWriter encoded;
  {
    ArgbBuffer image;
    ASSERT_WP2_OK(
        ReadImage(testutil::GetTestDataPath(file_name).c_str(), &image));
    const uint32_t crop_width = std::min(image.width(), max_tile_size);
    const uint32_t crop_height = std::min(image.height(), max_tile_size);
    ASSERT_WP2_OK(image.SetView(image, {0, 0, crop_width, crop_height}));
    ASSERT_WP2_OK(Encode(image, &encoded, config));
  }

  bool is_valid = false;
  ASSERT_WP2_OK(GetBistreamStats(encoded.mem_, encoded.size_, &is_valid));
  ASSERT_TRUE(is_valid);  // Nothing has been mutated yet.

  // Only the tile data gets mutated, everything before it is left intact.
  size_t tile_offset = 0;
  uint32_t tile_width = 0;
  uint32_t tile_height = 0;
  ASSERT_WP2_OK(GetTileBytesPosition(encoded.mem_, encoded.size_,
                                     &tile_offset, &tile_width, &tile_height));

  UniformIntDistribution random(seed);
  uint32_t num_valid = 0;
  Data mutated;
  ErrorTracer error_tracer;
  for (uint32_t attempt = 0; attempt < num_tries; ++attempt) {
    // Each try restarts from the original, unmutated bitstream.
    ASSERT_WP2_OK(mutated.CopyFrom(encoded.mem_, encoded.size_));
    ASSERT_WP2_OK(MutateBitstream(tile_width, tile_height,
                                  mutated.bytes + tile_offset,
                                  mutated.size - tile_offset, &random));
    ASSERT_WP2_OK(GetBistreamStats(mutated.bytes, mutated.size, &is_valid));
    error_tracer.Flush();
    if (is_valid) ++num_valid;
  }

  const float percentage_of_valid_bitstreams = 100.f * num_valid / num_tries;
  EXPECT_GE(percentage_of_valid_bitstreams, 0.f);
  if (!print_stats) return;
  const uint32_t total_bytes = (uint32_t)encoded.size_;
  const uint32_t tile_bytes = (uint32_t)(encoded.size_ - tile_offset);
  printf("Valid mutations: %u/%u (%.3f %%), tile bytes: %u/%u\n", num_valid,
         num_tries, percentage_of_valid_bitstreams, tile_bytes, total_bytes);
}
// Mutation tests at quality 75 on two small test images, with 50 mutated
// bitstreams per image.
INSTANTIATE_TEST_SUITE_P(ImageMutationTestInstantiationLossy, ImageMutationTest,
                         testing::Combine(testing::Values("source1_64x48.png",
                                                          "source1_32x32.png"),
                                          testing::Values(42),    // seed
                                          testing::Values(75.f),  // quality
                                          testing::Values(50),    // num tries
                                          testing::Values(false)  // print stats
                                          ));
// Same as the lossy instantiation but at quality 100 and with another seed.
INSTANTIATE_TEST_SUITE_P(ImageMutationTestInstantiationLossless,
                         ImageMutationTest,
                         testing::Combine(testing::Values("source1_64x48.png",
                                                          "source1_32x32.png"),
                                          testing::Values(12),     // seed
                                          testing::Values(100.f),  // quality
                                          testing::Values(50),     // num tries
                                          testing::Values(false)   // print stats
                                          ));
//------------------------------------------------------------------------------
} // namespace
} // namespace WP2