blob: 1f9035d38a65ccf5c8609d8f75b338e23cf96154 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -----------------------------------------------------------------------------
//
// WP2 lossy decoding of residuals.
//
// Author: Skal (pascal.massimino@gmail.com)
#include "src/common/constants.h"
#include "src/dec/wp2_dec_i.h"
#include "src/utils/ans_utils.h"
#include "src/utils/utils.h"
#include "src/wp2/decode.h"
namespace WP2 {
//------------------------------------------------------------------------------
// Reads, for 'channel', the symbol headers used to decode residuals. The same
// sequence is read once for EncodingMethod::kMethod0 and once for
// EncodingMethod::kMethod1:
//   - kSymbolResNumZeros (number of consecutive zeros),
//   - kSymbolBits0 (small residuals),
//   - kSymbolBits1 (prefixes of big residuals), only if the kSymbolBits0
//     header reports that the escape value kResidual1Max may be used.
// 'num_coeffs_max' is forwarded as-is to every SymbolReader::ReadHeader() call.
// Returns WP2_STATUS_OK once all headers were read; failures are expected to
// be propagated by WP2_CHECK_STATUS().
WP2Status ResidualReader::ReadHeaderForResidualSymbols(uint32_t num_coeffs_max,
                                                       Channel channel,
                                                       SymbolReader* const sr) {
  // Filled by SymbolReader::GetPotentialUsage() for each method.
  bool is_maybe_used[kResidual1Max + 1];

  for (const EncodingMethod method :
       {EncodingMethod::kMethod0, EncodingMethod::kMethod1}) {
    const uint32_t cluster = GetCluster(channel, num_channels_, method);
    const uint32_t method_index = GetMethodIndex(method);

    // Dictionary for the number of consecutive zeros.
    const std::string num_zeros_label =
        WP2SPrint("%s_num_zeros_%d", kChannelStr[channel], method_index);
    WP2_CHECK_STATUS(sr->ReadHeader(kSymbolResNumZeros, cluster, num_coeffs_max,
                                    num_zeros_label.c_str()));

    // Dictionary for small residuals.
    const std::string bits0_label =
        WP2SPrint("%s_bits0_%d", kChannelStr[channel], method_index);
    WP2_CHECK_STATUS(sr->ReadHeader(kSymbolBits0, cluster, num_coeffs_max,
                                    bits0_label.c_str()));

    // The dictionary for prefixes of big residuals is only present when the
    // escape value of the small-residual dictionary can actually occur.
    sr->GetPotentialUsage(kSymbolBits0, cluster, is_maybe_used,
                          kResidual1Max + 1);
    if (!is_maybe_used[kResidual1Max]) continue;

    const std::string bits1_label =
        WP2SPrint("%s_bits1_%d", kChannelStr[channel], method_index);
    WP2_CHECK_STATUS(sr->ReadHeader(kSymbolBits1, cluster, num_coeffs_max,
                                    bits1_label.c_str()));
  }
  return WP2_STATUS_OK;
}
// Reads all the symbol headers needed before decoding residuals.
//   sr:                 symbol reader the headers are read from.
//   num_coeffs_max_y:   value forwarded to the headers of the Y and A channels
//                       and to the kSymbolsForCoeffs headers.
//   num_coeffs_max_uv:  value forwarded to the headers of the U and V channels.
//   num_transforms:     value forwarded to the per-transform headers.
//   has_alpha:          selects 4 channels instead of 3 for num_channels_.
//   has_lossy_alpha:    if true, the A channel headers and the alpha encoding
//                       method header are read too.
// Returns WP2_STATUS_OK once all headers were read; failures are expected to
// be propagated by WP2_CHECK_STATUS().
WP2Status ResidualReader::ReadHeader(SymbolReader* const sr,
                                     uint32_t num_coeffs_max_y,
                                     uint32_t num_coeffs_max_uv,
                                     uint32_t num_transforms, bool has_alpha,
                                     bool has_lossy_alpha) {
  num_channels_ = has_alpha ? 4 : 3;

  if (!use_aom_coeffs_) {
    // Native residual coding: per-channel dictionaries first.
    for (const Channel channel :
         {kYChannel, kUChannel, kVChannel, kAChannel}) {
      const bool is_alpha = (channel == kAChannel);
      if (is_alpha && !has_lossy_alpha) continue;
      const bool uses_luma_max = (channel == kYChannel || is_alpha);
      const uint32_t num_coeffs_max =
          uses_luma_max ? num_coeffs_max_y : num_coeffs_max_uv;
      WP2_CHECK_STATUS(ReadHeaderForResidualSymbols(num_coeffs_max, channel, sr));
    }
    // Then the symbols shared by all coefficients and those read once per
    // transform.
    for (const Symbol sym : kSymbolsForCoeffs) {
      WP2_CHECK_STATUS(
          sr->ReadHeader(sym, num_coeffs_max_y, "residual_symbols"));
    }
    for (const Symbol sym : kSymbolsForCoeffsPerTf) {
      WP2_CHECK_STATUS(sr->ReadHeader(sym, num_transforms, "block_symbols"));
    }
  } else {
    // AOM-style residual coding only needs its own set of symbols.
    for (const Symbol sym : kSymbolsForAOMCoeffs) {
      WP2_CHECK_STATUS(sr->ReadHeader(sym, num_transforms, "aom_symbols"));
    }
  }

  // Headers common to both residual coding schemes.
  if (has_lossy_alpha) {
    WP2_CHECK_STATUS(sr->ReadHeader(kSymbolEncodingMethodA, num_transforms,
                                    "coeff_method_alpha"));
  }
  WP2_CHECK_STATUS(
      sr->ReadHeader(kSymbolHasCoeffs, num_transforms, "has_coeffs"));
  WP2_CHECK_STATUS(
      sr->ReadHeader(kSymbolEncodingMethod, num_transforms, "coeff_method"));
  return WP2_STATUS_OK;
}
//------------------------------------------------------------------------------
// Reads the kSymbolDC symbol for 'channel' and maps it to a signed value.
// If 'can_be_zero' is false, the decoded code is incremented by one before the
// mapping. Odd codes map to negative values and even codes to non-negative
// ones: 0 -> 0, 1 -> -1, 2 -> 1, 3 -> -2, ...
static int32_t ReadDC(Channel channel, uint32_t num_channels, bool can_be_zero,
                      SymbolReader* const sr) {
  const uint32_t cluster = ResidualIO::GetCluster(channel, num_channels);
  uint32_t code = sr->Read(kSymbolDC, cluster, "DC");
  if (!can_be_zero) ++code;

  // Unsigned-to-signed conversion: the lowest bit carries the sign.
  const bool is_negative = ((code & 1) != 0);
  if (is_negative) {
    return -static_cast<int32_t>((code + 1) >> 1);
  }
  return static_cast<int32_t>(code >> 1);
}
// Reads the first bound of the residual box and, when signaled, the second
// one.
//   is_x_first:   true if the first bound is along X, false if along Y.
//   use_bounds2:  set to whether a second bound was read.
//   val1:         receives the first bound.
//   val2:         receives the second bound; left untouched if *use_bounds2
//                 ends up false.
// The 'use_bound2' symbol is only read when the range returned for the second
// bound is usable (min != 255 and min <= max).
inline void ReadBounds(Channel channel, uint32_t num_channels,
                       EncodingMethod method, TrfSize tdim, bool is_x_first,
                       SymbolReader* const sr, bool* const use_bounds2,
                       uint32_t* const val1, uint32_t* const val2) {
  // First bound: its range only depends on the transform size.
  uint8_t first_min, first_max;
  if (is_x_first) {
    ResidualBoxAnalyzer::GetRangeX(tdim, &first_min, &first_max);
  } else {
    ResidualBoxAnalyzer::GetRangeY(tdim, &first_min, &first_max);
  }
  *val1 = sr->dec()->ReadRange(first_min, first_max, "bound1");

  // Second bound: its range also depends on the value of the first bound.
  uint8_t second_min, second_max;
  if (is_x_first) {
    ResidualBoxAnalyzer::GetRangePerX(tdim, *val1, &second_min, &second_max);
  } else {
    ResidualBoxAnalyzer::GetRangePerY(tdim, *val1, &second_min, &second_max);
  }

  const bool can_signal_second_bound =
      (second_min != 255 && second_min <= second_max);
  if (!can_signal_second_bound) {
    *use_bounds2 = false;
    return;
  }

  const uint32_t cluster =
      ResidualReader::GetClusterMergedUV(channel, num_channels, method, tdim);
  *use_bounds2 = sr->Read(kSymbolResidualUseBound2, cluster, "use_bound2");
  if (*use_bounds2) {
    *val2 = sr->dec()->ReadRange(second_min, second_max, "bound2");
  }
}
// Reads the coefficients of every transform of 'channel' in block 'cb'.
// Each transform's coefficient buffer is zeroed first, then filled according
// to the residual coding in use, and cb->num_coeffs_[channel][tf_i] is updated
// with the value returned by the coefficient reader (0 for kAllZero).
//   dequant:  dequantization factors forwarded to the coefficient readers.
//   info:     optional debug sink, forwarded to ReadCoeffsMethod01().
// Always returns WP2_STATUS_OK.
WP2Status ResidualReader::ReadCoeffs(Channel channel, const int16_t dequant[],
                                     SymbolReader* const sr,
                                     CodedBlockBase* const cb,
                                     BlockInfo* const info) {
  // U and V share the same debug prefix.
  const char* debug_label = "Y";
  if (channel == kUChannel || channel == kVChannel) {
    debug_label = "UV";
  } else if (channel == kAChannel) {
    debug_label = "A";
  }
  ANSDebugPrefix prefix(sr->dec(), debug_label);

  CodedBlockBase::CodingParams* const params = cb->GetCodingParams(channel);
  const BlockSize split_size = GetSplitSize(cb->dim(), params->split_tf);
  const uint32_t step_x = BlockWidthPix(split_size);
  const uint32_t step_y = BlockHeightPix(split_size);

  // Visit the transforms in raster order; 'tf_i' is the running transform
  // index.
  uint32_t tf_i = 0;
  for (uint32_t y = 0; y < cb->h_pix(); y += step_y) {
    for (uint32_t x = 0; x < cb->w_pix(); x += step_x, ++tf_i) {
      int16_t* const coeffs = cb->coeffs_[channel][tf_i];
      std::fill_n(coeffs, cb->NumCoeffsPerTransform(channel), 0);

      auto& num_coeffs = cb->num_coeffs_[channel][tf_i];
      if (use_aom_coeffs_) {
        // Coefficients only need to be read if some were announced.
        if (num_coeffs > 0) {
          num_coeffs = libgav1::AOMResidualReader::ReadCoeffs(
              channel, dequant, cb->tdim(channel), params->tf,
              cb->IsFirstCoeffDC(channel), sr, coeffs);
        }
        // TODO(skal): properly initialize method_[]
        continue;
      }

      if (cb->method_[channel][tf_i] == EncodingMethod::kAllZero) {
        num_coeffs = 0;
      } else {
        num_coeffs = ReadCoeffsMethod01(channel, dequant, tf_i, sr, cb, info);
      }
    }
  }
  return WP2_STATUS_OK;
}
// Decodes the coefficients of transform 'tf_i' of 'channel' in block 'cb' and
// stores the dequantized values in cb->coeffs_[channel][tf_i]. Called by
// ReadCoeffs() when AOM coefficients are not used and the transform's encoding
// method is not kAllZero.
//   dequant:  dequantization factors; dequant[0] is applied to the DC and
//             dequant[x + y * WP2QStride] to the coefficient at (x, y).
//   sr:       symbol reader the coefficients are decoded from.
//   info:     if not null, human-readable debug strings are appended to
//             info->residual_info[channel][tf_i].
// Returns the number of coefficients decoded as non-zero (including the DC).
uint32_t ResidualReader::ReadCoeffsMethod01(
Channel channel, const int16_t dequant[], uint32_t tf_i,
SymbolReader* const sr, CodedBlockBase* const cb, BlockInfo* const info) {
int16_t* const coeffs = cb->coeffs_[channel][tf_i];
const EncodingMethod method = cb->method_[channel][tf_i];
const TrfSize tdim = cb->tdim(channel);
const bool first_is_dc = cb->IsFirstCoeffDC(channel);
uint32_t num_nz_coeffs = 0;
if (first_is_dc) {
// Read the DC.
// With kDCOnly, ReadDC() increments the decoded code by one before mapping it
// to a signed value, so a zero code is not mapped to a zero DC.
const bool can_be_zero = (method != EncodingMethod::kDCOnly);
const int32_t dc = ReadDC(channel, num_channels_, can_be_zero, sr);
// The dequantized DC is clamped to the signed 16-bit range.
coeffs[0] = Clamp(dc * dequant[0], -32768, 32767);
num_nz_coeffs = (dc != 0) ? 1 : 0;
// Nothing else to read for a DC-only transform.
if (method == EncodingMethod::kDCOnly) return num_nz_coeffs;
} else {
assert(method != EncodingMethod::kDCOnly);
}
// All the symbols read below are prefixed with "C<method index>" for debug.
const std::string label = WP2SPrint("C%d", GetMethodIndex(method));
ANSDebugPrefix coeff_prefix(sr->dec(), label.c_str());
const uint32_t bw = TrfWidth[tdim];
const uint32_t bh = TrfHeight[tdim];
// Optional bounding box of the coefficients: (max_x, max_y) are passed to the
// BoundedResidualIterator below. Without bounds they cover the whole transform.
bool use_bounds_x, use_bounds_y;
uint32_t max_x, max_y;
uint32_t ind_min = 0u;
bool can_use_bounds_x, can_use_bounds_y;
ResidualBoxAnalyzer::CanUseBounds(tdim, &can_use_bounds_x, &can_use_bounds_y);
// The 'use_bounds' symbol is only read if at least one bound is possible for
// this transform size.
if ((can_use_bounds_x || can_use_bounds_y) &&
sr->Read(kSymbolResidualUseBounds,
GetClusterMergedUV(channel, num_channels_, method, tdim),
"use_bounds")) {
// Which bound comes first is only signaled when both are possible.
bool is_x_first;
if (can_use_bounds_x && !can_use_bounds_y) {
is_x_first = true;
} else if (!can_use_bounds_x && can_use_bounds_y) {
is_x_first = false;
} else {
is_x_first =
sr->Read(kSymbolResidualBound1IsX,
GetClusterMergedUV(channel, num_channels_, method, tdim),
"is_x_first");
}
// The first bound is always used; ReadBounds() tells whether the second one
// is. An unused second bound defaults to the full transform dimension.
if (is_x_first) {
use_bounds_x = true;
ReadBounds(channel, num_channels_, method, tdim, /*is_x_first=*/true,
sr, &use_bounds_y, &max_x, &max_y);
if (!use_bounds_y) max_y = bh - 1;
} else {
use_bounds_y = true;
ReadBounds(channel, num_channels_, method, tdim, /*is_x_first=*/false,
sr, &use_bounds_x, &max_y, &max_x);
if (!use_bounds_x) max_x = bw - 1;
}
// Figure out the minimal index we will reach (and therefore the one from
// which we need to store EOB).
uint32_t min_zig_zag_ind_x, min_zig_zag_ind_y;
ResidualBoxAnalyzer::FindBounds(tdim, max_x, max_y, &min_zig_zag_ind_x,
&min_zig_zag_ind_y);
if (use_bounds_x) ind_min = min_zig_zag_ind_x;
if (use_bounds_y) ind_min = std::max(ind_min, min_zig_zag_ind_y);
} else {
use_bounds_x = use_bounds_y = false;
max_x = bw - 1;
max_y = bh - 1;
}
// Debug.
if (info != nullptr) {
if (use_bounds_x) {
info->residual_info[channel][tf_i].push_back("use x bound " +
std::to_string(max_x));
}
if (use_bounds_y) {
info->residual_info[channel][tf_i].push_back("use y bound " +
std::to_string(max_y));
}
}
// Decoding state:
//  - has_written_zeros: the previous iteration consumed a run of zeros, so the
//    current coefficient is read as non-zero without an 'is_zero' symbol.
//  - has_only_ones_left: every remaining non-zero coefficient has magnitude 1
//    and 'is_one' is no longer read.
//  - previous_is_a_one: the previously decoded non-zero coefficient had
//    magnitude 1.
bool has_written_zeros = false;
bool has_only_ones_left = false;
bool previous_is_a_one = false;
uint32_t sector_cluster;
BoundedResidualIterator iter(tdim, use_bounds_x, use_bounds_y, max_x, max_y);
if (first_is_dc) ++iter;  // Skip the DC.
// The iterator is advanced inside the body, hence the empty increment clause.
for (; !iter.IsDone();) {
const uint32_t x = iter.x();
const uint32_t y = iter.y();
const uint32_t i = iter.Index();
// The statistics cluster depends on the sector the position falls in.
const uint32_t sector = ResidualIO::GetSector(x, y, tdim);
sector_cluster =
GetClusterMergedUV(channel, num_channels_, method, tdim, sector);
// If we have more than one element left and not written 0s before, check if
// there is any 0 coming.
if (!has_written_zeros && iter.MaxNumCoeffsLeft() > 1 &&
sr->Read(kSymbolResidualIsZero, sector_cluster, "is_zero")) {
// Read the number of consecutive 0s we have, by batches of
// kResidualCons0Max.
// The current (zero) coefficient is skipped first; the counts read below are
// the additional zeros following it.
++iter;
int32_t num_zeros_tmp;
do {
if (iter.MaxNumCoeffsLeft() < kResidualCons0Max + 1u) {
// If the number of elements left is smaller than the max number of
// possible 0s plus one non-zero element, go with a range to force
// feasibility.
num_zeros_tmp =
sr->dec()->ReadRValue(iter.MaxNumCoeffsLeft(), "num_zeros");
} else {
num_zeros_tmp = sr->Read(kSymbolResNumZeros,
GetCluster(channel, num_channels_, method),
"num_zeros");
}
for (int32_t j = 0; j < num_zeros_tmp; ++j) ++iter;
// A full batch means that another batch follows.
} while (num_zeros_tmp == kResidualCons0Max);
has_written_zeros = true;
continue;
}
has_written_zeros = false;
iter.SetAsNonZero();
// Magnitude of the coefficient: 1, 2, 3 + residual1 or, when residual1 is the
// escape value kResidual1Max, 3 + kResidual1Max + a prefix-coded value.
uint32_t abs_v;
if (has_only_ones_left ||
sr->Read(kSymbolResidualIsOne, sector_cluster, "is_one")) {
abs_v = 1;
} else {
if (sr->Read(kSymbolResidualIsTwo, sector_cluster, "is_two")) {
abs_v = 2;
} else {
const uint32_t residual1 =
sr->Read(kSymbolBits0, GetCluster(channel, num_channels_, method),
"residual1");
if (residual1 == kResidual1Max) {
abs_v = 3 + kResidual1Max;
const uint32_t prefix =
sr->Read(kSymbolBits1, GetCluster(channel, num_channels_, method),
"residual2_prefix");
const uint32_t extra = sr->dec()->ReadUValue(
PrefixCode::NumExtraBits(prefix, /*prefix_size=*/0),
"residual2_extra");
abs_v += PrefixCode::Merge(prefix, /*prefix_size=*/0, extra);
assert(abs_v <= kMaxCoeffValue);
} else {
abs_v = 3 + residual1;
}
}
}
// NOTE(review): '-abs_v' is computed on a uint32_t and then narrowed to
// int16_t; this relies on the modular unsigned-to-signed conversion.
coeffs[i] = sr->dec()->ReadBool("is_negative") ? -abs_v : abs_v;
// TODO(skal): clamp coeffs[i] to signed-16b?
// Unlike the DC above, the dequantized value is stored without clamping.
coeffs[i] *= dequant[x + y * WP2QStride];
++num_nz_coeffs;
// Remember the zigzag index of this coefficient before moving on: it is
// compared to 'ind_min' in the End Of Block test below.
const uint32_t zigzag_ind = iter.ZigZagIndex();
++iter;
// Exit if we are at the last element.
if (iter.IsDone()) break;
// Read an End Of Block if we have reached both sides of the box.
if (zigzag_ind >= ind_min && iter.CanEOB() &&
sr->Read(kSymbolResidualEndOfBlock, sector_cluster, "eob")) {
break;
}
// 'has_only_ones_left' is only read after a magnitude of 1 that does not
// directly follow another magnitude of 1, and only until it is set.
if (abs_v == 1 && !has_only_ones_left && !previous_is_a_one) {
has_only_ones_left =
(bool)sr->Read(kSymbolResidualHasOnlyOnesLeft, sector_cluster,
"has_only_ones_left");
}
previous_is_a_one = (abs_v == 1);
}
// Debug.
if (info != nullptr) {
// Compute the real box bounds.
// Coefficients are indexed in raster order here: i = x + y * bw.
uint32_t real_max_x = 0u, real_max_y = 0u;
const uint32_t max_i = cb->NumCoeffsPerTransform(channel) - 1;
for (uint32_t i = 0u; i <= max_i; ++i) {
const uint32_t x = i % bw;
const uint32_t y = i / bw;
if (coeffs[i] != 0) {
real_max_x = std::max(real_max_x, x);
real_max_y = std::max(real_max_y, y);
}
}
// last_i contains the index of the last element in the whole block.
// last_j contains the index of the last element in the chosen box (there
// can be none).
// last_k contains the index of the last element in the bounding box.
uint32_t last_i = 0u, last_j = 0u, last_k = 0u;
// j counts the positions inside the chosen box (max_x, max_y) and k the ones
// that are also inside the real bounding box (real_max_x, real_max_y).
for (uint32_t i = 0u, j = 0u, k = 0u; i <= max_i; ++i) {
const uint32_t x = i % bw;
const uint32_t y = i / bw;
// No need to store anything if we are out of bounds: it is 0s.
if (x > max_x || y > max_y) continue;
if (coeffs[i] != 0) {
last_i = i;
last_j = j;
last_k = k;
}
++j;
if (x <= real_max_x && y <= real_max_y) ++k;
}
assert(last_j <= last_i);
assert(last_k <= last_j);
info->residual_info[channel][tf_i].push_back("last index: " +
std::to_string(last_i));
// Report the indices that are relevant for the kind of box that was signaled.
if (use_bounds_x && use_bounds_y) {
info->residual_info[channel][tf_i].push_back(
"last index in chosen full box: " + std::to_string(last_j));
} else if (use_bounds_x || use_bounds_y) {
info->residual_info[channel][tf_i].push_back(
"last index in chosen box: " + std::to_string(last_j));
info->residual_info[channel][tf_i].push_back("last index in full box: " +
std::to_string(last_k));
} else {
info->residual_info[channel][tf_i].push_back("last index in full box: " +
std::to_string(last_k));
}
}
return num_nz_coeffs;
}
//------------------------------------------------------------------------------
} // namespace WP2