blob: 4d502ff2ad35a5b4185d81ba195d0242ce30d553 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -----------------------------------------------------------------------------
//
// Context-based predictions and syntax I/O
//
// Author: Skal (pascal.massimino@gmail.com)
#include "src/common/lossy/context.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>  // for std::greater<>

#include "src/common/lossy/block.h"
#include "src/common/lossy/predictor.h"
#include "src/dec/symbols_dec.h"
#include "src/enc/symbols_enc.h"
namespace WP2 {
// Number of rows of each BlockContextPredictor context map ('map_' is sized
// SizeBlocks(width) * kMapHeight and handed to 'ctxt_' with this height).
// NOTE(review): presumably rows are in block units and kMaxBlockSize + 1 rows
// are enough to hold the context of the tallest block — confirm against the
// context-map class used by 'ctxt_'.
constexpr uint32_t kMapHeight = kMaxBlockSize + 1;
//------------------------------------------------------------------------------
// BlockContextPredictor
//------------------------------------------------------------------------------
// Prepares the predictor for symbol 'symbol', whose values lie in
// [0, symbol_range). 'width' is converted to a number of columns with
// SizeBlocks() and the context map gets kMapHeight rows.
// Returns WP2_STATUS_INVALID_PARAMETER if 'symbol_range' exceeds kMaxRange,
// or an allocation error status if the map cannot be resized.
WP2Status BlockContextPredictor::Init(uint8_t symbol, uint32_t symbol_range,
                                      uint32_t width) {
  WP2_CHECK_OK(symbol_range <= kMaxRange, WP2_STATUS_INVALID_PARAMETER);
  symbol_ = symbol;
  range_ = symbol_range;
  const uint32_t num_columns = SizeBlocks(width);
  WP2_CHECK_ALLOC_OK(map_.resize(num_columns * kMapHeight));
  ctxt_.Init(&map_, num_columns, kMapHeight);
  return WP2_STATUS_OK;
}
// Stores 'value' in the context map for the area of block 'cb', where
// GetContext() of neighboring blocks can later pick it up.
void BlockContextPredictor::SetValue(const CodedBlockBase& cb, uint8_t value) {
  const auto& area = cb.blk();
  ctxt_.Set(area, value);
}
// Turns this predictor into a copy of 'other': symbol, range and the whole
// context map. 'ctxt_' only points to 'map_', so it is re-initialized on our
// own map instead of being copied.
WP2Status BlockContextPredictor::CopyFrom(const BlockContextPredictor& other) {
  symbol_ = other.symbol_;
  range_ = other.range_;
  WP2_CHECK_ALLOC_OK(map_.copy_from(other.map_));
  // The map always holds kMapHeight rows, so its column count follows from
  // its total size.
  const uint32_t num_columns = (uint32_t)map_.size() / kMapHeight;
  assert(num_columns * kMapHeight == other.map_.size());
  ctxt_.Init(&map_, num_columns, kMapHeight);
  return WP2_STATUS_OK;
}
void BlockContextPredictor::WriteValue(const CodedBlockBase& cb, uint8_t value,
SymbolManager* const sm,
ANSEncBase* const enc,
WP2_OPT_LABEL) const {
sm->Process(symbol_, GetSlot(cb, value), label, enc);
}
// Inverse of WriteValue(): reads a slot for symbol 'symbol_' and maps it back
// to the actual value using the context-sorted list of block 'cb'.
uint8_t BlockContextPredictor::ReadValue(const CodedBlockBase& cb,
                                         SymbolReader* const sr,
                                         WP2_OPT_LABEL) {
  const uint32_t slot = sr->Read(symbol_, label);
  // GetContext() writes 'range_' entries and Init() only guarantees
  // range_ <= kMaxRange, so the buffer must be sized by kMaxRange (as in
  // GetSlot()), not by kMaxNumSegments.
  uint8_t context[kMaxRange];
  GetContext(cb, context);
  // Only the first 'range_' entries are initialized.
  assert(slot < range_);
  return context[slot];
}
void BlockContextPredictor::GetContext(const CodedBlockBase& cb,
uint8_t* const context) const {
uint8_t context_tmp[kMaxContextSize];
int8_t y_left_occ, y_right_occ;
cb.GetOccupancy(&y_left_occ, &y_right_occ, /*top_context_extent=*/nullptr);
const uint32_t size =
ctxt_.GetContext(cb.blk(), y_left_occ, y_right_occ, context_tmp);
// we use the lower 8bit to store the id, the upper 8 to store the counts
uint16_t counts[kMaxNumSegments] = {0};
for (uint32_t i = 0; i < size; ++i) counts[context_tmp[i]] += (1 << 8);
for (uint32_t v = 0; v < range_; ++v) counts[v] |= v;
std::sort(counts, counts + range_, std::greater<int>());
for (uint32_t i = 0; i < range_; ++i) {
context[i] = counts[i] & 0xff; // retrieve the id
}
}
// Returns the position of 'id' in the context-sorted list of values for 'cb'
// (see GetContext()). Every value in [0, range_) appears in that list, so an
// 'id' outside this range is a programming error.
uint32_t BlockContextPredictor::GetSlot(const CodedBlockBase& cb,
                                        uint8_t id) const {
  uint8_t sorted_ids[kMaxRange];
  GetContext(cb, sorted_ids);
  const uint8_t* const last = sorted_ids + range_;
  const uint8_t* const found = std::find(sorted_ids, last, id);
  if (found != last) return (uint32_t)(found - sorted_ids);
  assert(0);  // shouldn't be reached
  return 0;
}
//------------------------------------------------------------------------------
// SegmentIdPredictor
//------------------------------------------------------------------------------
void SegmentIdPredictor::InitInitialSegmentId(const CodedBlockBase& cb,
uint32_t id) {
predictor_.SetValue(cb, id);
}
// Copies the configuration and the segment-id context map of 'other'.
// Returns the status of the underlying map copy.
WP2Status SegmentIdPredictor::CopyFrom(const SegmentIdPredictor& other) {
  explicit_segment_ids_ = other.explicit_segment_ids_;
  num_segments_ = other.num_segments_;
  const auto& other_predictor = other.predictor_;
  WP2_CHECK_STATUS(predictor_.CopyFrom(other_predictor));
  return WP2_STATUS_OK;
}
//------------------------------------------------------------------------------
// Writing
// Encoder-side setup. 'num_segments' ids are available; when
// 'explicit_segment_ids' is false, ids are derived from block sizes instead
// of being coded. The settings are only stored once the context map is ready.
WP2Status SegmentIdPredictor::InitWrite(uint32_t num_segments,
                                        bool explicit_segment_ids,
                                        uint32_t width) {
  assert(num_segments <= kMaxNumSegments);
  const WP2Status init_status =
      predictor_.Init(kSymbolSegmentId, num_segments, width);
  WP2_CHECK_STATUS(init_status);
  explicit_segment_ids_ = explicit_segment_ids;
  num_segments_ = num_segments;
  return WP2_STATUS_OK;
}
// Writes the segment-id header. Nothing is written for now, which mirrors
// ReadHeader() not reading anything from its decoder.
WP2Status SegmentIdPredictor::WriteHeader(ANSEncBase* const enc) const {
  (void)enc;  // Unused for now; avoids -Wunused-parameter like '(void)dec'.
  return WP2_STATUS_OK;
}
// Codes the segment id of 'cb' when ids are explicit and the block is not
// small. Otherwise nothing is written: the id must then be the one implied by
// the block size (checked in debug builds).
void SegmentIdPredictor::WriteId(const CodedBlockBase& cb,
                                 SymbolManager* const sm,
                                 ANSEncBase* const enc) const {
  const bool is_coded = explicit_segment_ids_ && !cb.blk().IsSmall();
  if (!is_coded) {
    assert(cb.id_ == GetIdFromSize(num_segments_, cb.dim()));
    return;
  }
  predictor_.WriteValue(cb, cb.id_, sm, enc, "segment_id");
}
void SegmentIdPredictor::RecordId(const CodedBlockBase& cb) {
predictor_.SetValue(cb, cb.id_); // Fill the segment id map.
}
//------------------------------------------------------------------------------
// Reading
// Decoder-side setup, counterpart of InitWrite(). Nothing is read from 'dec'
// yet; the settings are stored and the context map is allocated.
WP2Status SegmentIdPredictor::ReadHeader(ANSDec* const dec,
                                         uint32_t num_segments,
                                         bool explicit_segment_ids,
                                         uint32_t width) {
  (void)dec;
  explicit_segment_ids_ = explicit_segment_ids;
  num_segments_ = num_segments;
  // Statistics are currently assumed identical over the whole image.
  const WP2Status init_status =
      predictor_.Init(kSymbolSegmentId, num_segments, width);
  WP2_CHECK_STATUS(init_status);
  return WP2_STATUS_OK;
}
// Sets 'cb->id_', either by decoding it (explicit ids, non-small block) or by
// deriving it from the block size. Then records it in the context map.
void SegmentIdPredictor::ReadId(SymbolReader* const sr,
                                CodedBlockBase* const cb) {
  const bool is_coded = explicit_segment_ids_ && !cb->blk().IsSmall();
  cb->id_ = is_coded ? predictor_.ReadValue(*cb, sr, "segment_id")
                     : GetIdFromSize(num_segments_, cb->dim());
  assert(cb->id_ < num_segments_);
  predictor_.SetValue(*cb, cb->id_);
}
// Implicit segment id for a block of size 'block_size'.
// floor(log2) of the mean of BlockWidth and BlockHeight gives a size class in
// [0:3] for blocks [4x4:32x32]. That class is then spread linearly over
// [0, num_segments - 1] with rounding.
uint8_t SegmentIdPredictor::GetIdFromSize(uint32_t num_segments,
                                          BlockSize block_size) {
  const uint32_t mean_side =
      (BlockWidth[block_size] + BlockHeight[block_size]) / 2;
  const uint8_t size_class = (uint8_t)WP2Log2Floor(mean_side);
  const uint32_t scaled_class = size_class * (num_segments - 1);
  return DivRound<uint8_t>(scaled_class, 3);
}
//------------------------------------------------------------------------------
// AlphaModePredictor
//------------------------------------------------------------------------------
// Sets up the alpha-mode predictor. Block alpha modes take 3 distinct values.
// Only 'width' is used, forwarded to BlockContextPredictor::Init().
WP2Status AlphaModePredictor::Init(uint32_t width, uint32_t height) {
  (void)height;  // Unused; silences -Wunused-parameter.
  return predictor_.Init(kSymbolBlockAlphaMode, /*symbol_range=*/3, width);
}
// Copies the alpha-mode context map of 'other'.
WP2Status AlphaModePredictor::CopyFrom(const AlphaModePredictor& other) {
  const auto& source = other.predictor_;
  return predictor_.CopyFrom(source);
}
void AlphaModePredictor::Update(const CodedBlockBase& cb) {
predictor_.SetValue(cb, (uint8_t)cb.alpha_mode_);
}
void AlphaModePredictor::Write(const CodedBlockBase& cb, ANSEncBase* const enc,
SymbolManager* const sw) {
predictor_.WriteValue(cb, (uint8_t)cb.alpha_mode_, sw, enc, "alpha_mode");
Update(cb);
}
// Decodes the alpha mode of 'cb' and records it in the context map. The
// decoded value is stored directly, because Update() reads cb.alpha_mode_,
// which the caller has not set yet ('cb' is const here).
BlockAlphaMode AlphaModePredictor::Read(const CodedBlockBase& cb,
                                        SymbolReader* const sr) {
  const uint8_t mode = predictor_.ReadValue(cb, sr, "alpha_mode");
  predictor_.SetValue(cb, mode);
  return (BlockAlphaMode)mode;
}
//------------------------------------------------------------------------------
// DCDiffusionMap
// Allocates kErrorHeight error rows, each SizeBlocks(width) entries wide, and
// resets the map with Clear().
WP2Status DCDiffusionMap::Init(uint32_t width) {
  const uint32_t num_columns = SizeBlocks(width);
  for (uint32_t row = 0; row < kErrorHeight; ++row) {
    WP2_CHECK_ALLOC_OK(error_[row].resize(num_columns));
  }
  Clear();
  return WP2_STATUS_OK;
}
// Copies every error row and 'max_row_' from 'other'.
WP2Status DCDiffusionMap::CopyFrom(const DCDiffusionMap& other) {
  for (uint32_t row = 0; row < kErrorHeight; ++row) {
    const auto& source_row = other.error_[row];
    WP2_CHECK_ALLOC_OK(error_[row].copy_from(source_row));
  }
  max_row_ = other.max_row_;
  return WP2_STATUS_OK;
}
// Resets the map. Only row 0 is zeroed here: Store() zeroes each further row
// when 'max_row_' first advances to it.
void DCDiffusionMap::Clear() {
  for (auto& cell : error_[0]) cell = 0.f;
  max_row_ = 0;
}
// Spreads 'error' for block 'blk' evenly over the cells surrounding it:
//  - column blk.x() - 1, rows [left_occupancy, block_max_row),
//  - column blk.x() + blk.w(), rows [right_occupancy, block_max_row),
//  - row block_max_row (just below the block), columns of the block.
// The neighbor occupancies come from mgr.GetOccupancy() and are clamped to
// the block's bottom. Each cell receives error / border_size.
// Columns index error_[row], which Init() sizes to SizeBlocks(width), so x is
// in block units. Rows are stored in a ring buffer indexed by y % kErrorHeight.
// NOTE(review): assumes kErrorHeight is large enough that rows still needed
// are never recycled by the modulo — confirm against the header.
void DCDiffusionMap::Store(const FrontMgrBase& mgr, const Block& blk,
                           int16_t error) {
  const uint32_t block_max_row = blk.y() + blk.h();
  // Rows reached for the first time may hold stale values from the ring
  // buffer, so zero them before advancing 'max_row_'.
  if (block_max_row > max_row_) {
    for (uint32_t i = max_row_ + 1; i <= block_max_row; ++i) {
      for (auto& l : error_[i % kErrorHeight]) l = 0.f;
    }
    max_row_ = block_max_row;
  }
  // At the left image edge, left_occupancy == block_max_row, so the left loop
  // below is empty and the wrapped index blk.x() - 1 is never used.
  const uint32_t left_occupancy =
      (blk.x() > 0) ? std::min(mgr.GetOccupancy(blk.x() - 1), block_max_row)
                    : block_max_row;
  const uint32_t width = error_[0].size();
  // Likewise at the right image edge, the right loop below is empty.
  const uint32_t right_occupancy =
      (blk.x() + blk.w() < width)
          ? std::min(mgr.GetOccupancy(blk.x() + blk.w()), block_max_row)
          : block_max_row;
  // Total number of receiving cells; at least blk.w(), so never zero.
  const uint32_t border_size = blk.w() + (block_max_row - left_occupancy) +
                               (block_max_row - right_occupancy);
  const float error_per_block = (float)error / border_size;
  for (uint32_t y = left_occupancy; y < block_max_row; ++y) {
    error_[y % kErrorHeight][blk.x() - 1] += error_per_block;
  }
  for (uint32_t y = right_occupancy; y < block_max_row; ++y) {
    error_[y % kErrorHeight][blk.x() + blk.w()] += error_per_block;
  }
  for (uint32_t x = blk.x(); x < blk.x() + blk.w(); ++x) {
    error_[block_max_row % kErrorHeight][x] += error_per_block;
  }
}
// Sums the diffused error stored on the cells of 'blk' that border its
// neighbors, then returns round(sum * strength / 255). The cells are:
//  - its leftmost and rightmost columns, for rows up to 'max_row_',
//  - the interior cells of its top row.
// The summation order is unchanged, so float results are unchanged.
int16_t DCDiffusionMap::Get(const Block& blk, uint32_t strength) const {
  float new_error = 0;
  // Left and right columns (including the top-row corners).
  for (uint32_t y = blk.y(); y < (blk.y() + blk.h()) && y <= max_row_; ++y) {
    new_error += error_[y % kErrorHeight][blk.x()];
    if (blk.w() > 1) {
      new_error += error_[y % kErrorHeight][blk.x() + blk.w() - 1];
    }
  }
  // Interior of the top row. Store() zeroes a row only when 'max_row_'
  // reaches it, so a row beyond 'max_row_' may hold stale ring-buffer values.
  // Apply the same bound as the column loop above.
  if (blk.y() <= max_row_) {
    for (uint32_t x = blk.x() + 1; x < (blk.x() + blk.w()) - 1; ++x) {
      new_error += error_[blk.y() % kErrorHeight][x];
    }
  }
  return (int16_t)std::round(new_error * strength / 255.f);
}
} // namespace WP2