blob: 16cef35bfe80d4ec632340df46d031f19395f54c [file] [log] [blame]
// Copyright (c) the JPEG XL Project
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// -----------------------------------------------------------------------------
//
// Forked from https://github.com/google/pik/blob/master/pik/lossless8.cc
// at 16268ef512a65b541c7b5e485468a7ed33bc13d8
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include "src/common/lossless/calic.h"
#include "src/common/lossless/plane.h"
#include "src/common/lossless/scp.h"
#include "src/common/symbols.h"
#include "src/dec/lossless/losslessi_dec.h"
#include "src/dec/symbols_dec.h"
#include "src/dsp/dsp.h"
#include "src/utils/ans.h"
#include "src/utils/plane.h"
#include "src/utils/utils.h"
#include "src/utils/vector.h"
#include "src/wp2/base.h"
namespace WP2L {
// Returns the true pixel value (TPV) given its 'prediction' and a coded
// non-negative 'distance', for values bounded by [min, max].
//
// While both error signs are still possible, the distance is zig-zag decoded
// into a signed error: 0, +1, -1, +2, -2, ... E.g. for a prediction of 10 with
// values in [0, 100], distances 0 to 20 give errors in [-10, 10]. Any larger
// error can only be positive, so distances 21 to 100 map to values 21 to 100.
// For distances in [0, max - min] the result always lies in [min, max].
// 'use_opposite_error' flips the sign of the zig-zag error (see "End of section
// 5.D: Error Feedback", presumably of the CALIC paper).
int16_t CalcTPVFromPredictionAndDistance(int16_t prediction,
                                         bool use_opposite_error,
                                         int16_t distance, int16_t min,
                                         int16_t max) {
  // Largest error magnitude whose sign is not forced by where the prediction
  // sits inside [min, max].
  const int16_t max_signed_error =
      static_cast<int16_t>(std::min(max - prediction, prediction - min));
  if (distance > 2 * max_signed_error) {
    // Only one direction is left: step away from the nearer bound.
    const int16_t offset = static_cast<int16_t>(distance - max_signed_error);
    if (prediction + offset > max) {
      return static_cast<int16_t>(prediction - offset);
    }
    return static_cast<int16_t>(prediction + offset);
  }
  // Odd distances are positive errors, even distances negative ones.
  int16_t error = static_cast<int16_t>(
      (distance % 2 != 0) ? (distance + 1) / 2 : -distance / 2);
  if (use_opposite_error) error = static_cast<int16_t>(-error);
  return static_cast<int16_t>(prediction + error);
}
////////////////////////////////////////////////////////////////////////////////
namespace scp {
class CalicDecState : public ::WP2L::calic::CalicState {
 public:
  // Reconstructs plane 'plane_to_decompress' of 'img' in raster order. For
  // each pixel, the CALIC state provides a prediction, a context and whether
  // the error sign should be flipped. A distance is then read from 'sr' in
  // that context and turned back into the pixel value.
  // Returns WP2_STATUS_BITSTREAM_ERROR if a decoded value falls outside the
  // plane's [min_tpv, max_tpv] range.
  WP2Status DecompressPlane(int plane_to_decompress, const ScpGlobal& global,
                            WP2::SymbolReader& sr, ImagePlanes& img) {
    const int16_t lo = global.min_tpv[plane_to_decompress];
    const int16_t hi = global.max_tpv[plane_to_decompress];
    WP2_CHECK_STATUS(SetParameters(
        global.calic_quantizations[plane_to_decompress], lo, hi));
    const size_t num_rows = img.height();
    const size_t num_cols = img.width();
    for (size_t y = 0; y < num_rows; ++y) {
      int16_t* const line = img.PlaneRow(plane_to_decompress, y);
      StartProcessingLine(y, line);
      for (size_t x = 0; x < num_cols; ++x) {
        int16_t predicted;
        uint8_t context;
        bool flip_sign;
        Predict(x, predicted, context, flip_sign);
        const int16_t distance =
            sr.Read(/*sym=*/plane_to_decompress, context, "sym");
        const int16_t value = CalcTPVFromPredictionAndDistance(
            predicted, flip_sign, distance, lo, hi);
        line[x] = value;
        WP2_CHECK_OK(lo <= value && value <= hi, WP2_STATUS_BITSTREAM_ERROR);
        Update(x, value);
      }
    }
    return WP2_STATUS_OK;
  }
};
////////////////////////////////////////////////////////////////////////////////
// Decoder for the classical SCP.
class ClassicalDecState : public ClassicalState {
 public:
  // Reconstructs plane 'planeToDecompress' of 'img' in raster order. Each
  // pixel is predicted with PredictY0 on the first row, otherwise with the
  // predictor selected by global.prediction_modes. It is then corrected by a
  // distance read from 'sr' in the context returned by the predictor, shifted
  // by global.maxerr_shift.
  // Returns WP2_STATUS_BITSTREAM_ERROR if the stream selects an unsupported
  // prediction mode or decodes a value outside [min_tpv, max_tpv].
  WP2Status DecompressPlane(int planeToDecompress, const ScpGlobal& global,
                            WP2::SymbolReader& sr, ImagePlanes& img) {
    const int32_t min = global.min_tpv[planeToDecompress];
    const int32_t max = global.max_tpv[planeToDecompress];
    // The state's bounds are kept with kPredExtraBits of extra precision.
    min_tpv_ = min << kPredExtraBits;
    max_tpv_ = max << kPredExtraBits;
    // The prediction mode is constant for the whole plane.
    const PredictionMode pred_mode =
        global.prediction_modes[planeToDecompress];
    const auto regular_func =
        pred_mode == PredictionMode::Regular ? &ClassicalDecState::PredictRegular
        : pred_mode == PredictionMode::West  ? &ClassicalDecState::PredictWest
        : pred_mode == PredictionMode::North ? &ClassicalDecState::PredictNorth
                                             : nullptr;
    for (size_t y = 0; y < img.height(); ++y) {
      int16_t* row = img.PlaneRow(planeToDecompress, y);
      StartProcessingLine(y, row);
      // The first row uses its own predictor.
      const auto func =
          (y == 0) ? &ClassicalDecState::PredictY0 : regular_func;
      // Calling through a null member-function pointer would be undefined
      // behavior; an unknown prediction mode means a corrupt stream.
      WP2_CHECK_OK(func != nullptr, WP2_STATUS_BITSTREAM_ERROR);
      for (size_t x = 0; x < img.width(); ++x) {
        // Predict.
        uint8_t ctxt;
        const int16_t prediction = (this->*func)(x, &ctxt);
        // Coarsen the context as signaled in the header.
        ctxt >>= global.maxerr_shift;
        assert(ctxt <= kNumContexts - 1);
        // Compute the true pixel value.
        int q = sr.Read(/*sym=*/planeToDecompress, ctxt, "sym");
        row[x] = CalcTPVFromPredictionAndDistance(
            prediction, /*use_opposite_error=*/false, q, min, max);
        // Distances larger than max - min (only possible in a corrupt stream)
        // decode outside the plane's range; reject them like the CALIC
        // decoder does instead of feeding them to the error state.
        WP2_CHECK_OK(row[x] >= min && row[x] <= max,
                     WP2_STATUS_BITSTREAM_ERROR);
        UpdateErrors</*USE_JXL=*/false>(x, row[x], q);
      }
    }
    return WP2_STATUS_OK;
  }
};
////////////////////////////////////////////////////////////////////////////////
// Decoder for the JPEG XL SCP.
class JXLDecState : public JXLState {
 public:
  // Reconstructs plane 'planeToDecompress' of 'img' in raster order using the
  // JPEG XL-style predictor. Each decoded value is clamped to the plane's
  // [min_tpv, max_tpv] range before it updates the predictor's error state.
  WP2Status DecompressPlane(int planeToDecompress, const ScpGlobal& global,
                            WP2::SymbolReader& sr, ImagePlanes& img) {
    const int32_t lo = global.min_tpv[planeToDecompress];
    const int32_t hi = global.max_tpv[planeToDecompress];
    // The state's bounds are kept with kPredExtraBits of extra precision.
    min_tpv_ = lo << kPredExtraBits;
    max_tpv_ = hi << kPredExtraBits;
    for (size_t y = 0; y < img.height(); ++y) {
      int16_t* const line = img.PlaneRow(planeToDecompress, y);
      StartProcessingLine(y, line);
      for (size_t x = 0; x < img.width(); ++x) {
        uint8_t context;
        const int16_t guess = Predict(x, &context);
        const int distance =
            sr.Read(/*sym=*/planeToDecompress, context, "sym");
        const int16_t decoded = CalcTPVFromPredictionAndDistance(
            guess, /*use_opposite_error=*/false, distance, lo, hi);
        // TODO(vrabaud): do not clamp !!! Those are wasted bits.
        line[x] = std::clamp(static_cast<int32_t>(decoded), lo, hi);
        UpdateErrors</*USE_JXL=*/true>(x, line[x], distance);
      }
    }
    return WP2_STATUS_OK;
  }
};
////////////////////////////////////////////////////////////////////////////////
// Main function to decompress a buffer using an SCP state.
// Decodes every plane (plane 0 only when 'has_alpha') into a temporary
// ImagePlanes with 'state', then scatters it into 'out'. 'out' holds 'height'
// rows of 'width' pixels with 4 interleaved int16_t channels each. Channel 0
// of 'out' is left untouched when 'has_alpha' is false.
template <typename STATE>
WP2Status Decompress(const ScpGlobal& global, uint32_t width, uint32_t height,
                     bool has_alpha, STATE& state, WP2::SymbolReader& sr,
                     int16_t* out) {
  ImagePlanes img;
  WP2_CHECK_STATUS(img.Create(width, height, has_alpha));
  const uint32_t first_channel = has_alpha ? 0 : 1;
  for (uint32_t c = first_channel; c < 4; ++c) {
    // Every plane starts from a freshly reset state.
    WP2_CHECK_STATUS(state.Reset());
    WP2_CHECK_STATUS(state.DecompressPlane(c, global, sr, img));
    // Surface any error recorded by the underlying bit reader.
    WP2_CHECK_STATUS(sr.dec()->GetStatus());
    // Copy the plane into channel 'c' of the interleaved output.
    for (size_t y = 0; y < height; ++y) {
      const int16_t* const WP2_RESTRICT src = img.PlaneRow(c, y);
      int16_t* const dst = out + y * 4 * width + c;
      for (size_t x = 0; x < width; ++x) dst[4 * x] = src[x];
    }
  }
  return WP2_STATUS_OK;
}
} // namespace scp
// Reads the plane-codec header of the current tile into '*global'. The header
// holds the per-channel value ranges (from GetARGBRanges()), the coding method
// and its method-specific parameters. It then sets up hdr_.sr_ with one symbol
// per decoded channel (channel 0 only when alpha is present) and reads the
// entropy-coding header of every symbol cluster.
// The order of the reads below defines the bitstream layout and must match the
// encoder.
// Returns WP2_STATUS_INVALID_PARAMETER for an unknown method. Otherwise it
// returns the first error reported by a helper, or WP2_STATUS_OK.
WP2Status Decoder::DecodePlaneCodecHeader(ScpGlobal* global) {
  WP2_CHECK_STATUS(GetARGBRanges(global->min_tpv, global->max_tpv));
  global->method = static_cast<PlaneCodec::Method>(
      dec_->ReadRValue(static_cast<int>(PlaneCodec::Method::kNum), "method"));
  // Number of contexts (used as clusters below) per channel. Entries of
  // channels that are not decoded may stay uninitialized (e.g. channel 0
  // without alpha for kCalic); they are never read.
  std::array<int, 4> num_contexts;
  switch (global->method) {
    case PlaneCodec::Method::kScpClassical:
      global->maxerr_shift = dec_->ReadRValue(5, "maxerrShift");
      // All channels share the same number of contexts.
      num_contexts.fill(scp::ClassicalState::GetNumContextsFromErrShift(
          global->maxerr_shift));
      // One prediction mode per decoded channel.
      for (int c = gparams_->has_alpha_ ? 0 : 1; c < 4; ++c) {
        global->prediction_modes[c] =
            static_cast<scp::PredictionMode>(dec_->ReadRValue(
                static_cast<int>(scp::PredictionMode::Num), "pred_mode"));
      }
      break;
    case PlaneCodec::Method::kScpJxl:
      // An index equal to kJxlHeaders.size() means the header parameters
      // follow explicitly. Other values select a preset header (applied with
      // SetHeaderIndex() in DecodePlaneCodec()).
      global->jxl_header_index =
          dec_->ReadRange(0, scp::kJxlHeaders.size(), "header_index");
      if (global->jxl_header_index ==
          static_cast<int>(scp::kJxlHeaders.size())) {
        global->jxl_header.p1c = dec_->ReadRange(0, 40, "p1c");
        global->jxl_header.p2c = dec_->ReadRange(0, 40, "p2c");
        global->jxl_header.p3ca = dec_->ReadRange(0, 40, "p3ca");
        global->jxl_header.p3cb = dec_->ReadRange(0, 40, "p3cb");
        global->jxl_header.p3cc = dec_->ReadRange(0, 40, "p3cc");
        global->jxl_header.p3cd = dec_->ReadRange(0, 40, "p3cd");
        global->jxl_header.p3ce = dec_->ReadRange(0, 40, "p3ce");
        for (int i = 0; i < 4; ++i) {
          global->jxl_header.w[i] = dec_->ReadRange(0, 40, "w");
        }
      }
      // All channels share the same number of contexts.
      num_contexts.fill(scp::JXLState::GetNumContexts());
      break;
    case PlaneCodec::Method::kCalic:
      // Each decoded channel has its own quantization, hence its own number
      // of contexts.
      for (int c = gparams_->has_alpha_ ? 0 : 1; c < 4; ++c) {
        global->calic_quantizations[c] =
            static_cast<calic::CalicState::Quantization>(dec_->ReadRValue(
                static_cast<int>(calic::CalicState::Quantization::kNum),
                "quantization"));
        num_contexts[c] = calic::CalicState::GetNumContextsFromQuantization(
            global->calic_quantizations[c]);
      }
      break;
    case PlaneCodec::Method::kNum:
    default:
      // Presumably unreachable, assuming ReadRValue(kNum) returns values
      // below kNum.
      assert(false);
      return WP2_STATUS_INVALID_PARAMETER;
  }
  WP2::SymbolsInfo symbols_info;
  // NOTE(review): the reader is first initialized with an empty SymbolsInfo
  // (apparently only to attach dec_), then re-initialized below with the
  // populated one using hdr_.sr_.dec(). Confirm the first Init() is needed.
  WP2_CHECK_STATUS(hdr_.sr_.Init(symbols_info, dec_));
  // Symbol 'c' carries the distances decoded for channel 'c' (see the
  // DecompressPlane() implementations), with one cluster per context.
  // NOTE(review): CalcTPVFromPredictionAndDistance() only needs distances in
  // [0, max - min]. If SetInfo()'s max is inclusive, the '+ 1' admits one
  // extra value that decodes out of range. Confirm.
  for (uint32_t c = gparams_->has_alpha_ ? 0 : 1; c < 4; ++c) {
    symbols_info.SetInfo(/*sym=*/c, /*min=*/0,
                         /*max=*/global->max_tpv[c] - global->min_tpv[c] + 1,
                         /*num_clusters=*/num_contexts[c],
                         WP2::SymbolsInfo::StorageMethod::kAuto);
  }
  WP2_CHECK_STATUS(hdr_.sr_.Init(symbols_info, hdr_.sr_.dec()));
  WP2_CHECK_STATUS(hdr_.sr_.Allocate());
  // Read the entropy-coding header of every cluster of every symbol.
  for (uint32_t s = 0; s < symbols_info.Size(); ++s) {
    for (uint32_t c = 0; c < symbols_info.NumClusters(s); ++c) {
      WP2_CHECK_STATUS(
          hdr_.sr_.ReadHeader(s, c, tile_->rect.GetArea(), "header"));
    }
  }
  return WP2_STATUS_OK;
}
// While last_pixel_ is 0, decodes the whole tile into pixels_ with the decoder
// selected by 'global.method', then calls ProcessRows() for rows [0, last_row).
// On every call, '*src_out' (if not null) is set to the first sample of row
// 'last_row' in pixels_, which stores 4 int16_t samples per pixel.
// 'bits_per_pixel' is not used here.
WP2Status Decoder::DecodePlaneCodec(uint32_t width, uint32_t last_row,
                                    const ScpGlobal& global,
                                    WP2::Planef* bits_per_pixel,
                                    int16_t** src_out) {
  int16_t* const data = pixels_.data();
  if (last_pixel_ == 0) {
    // Decode everything the first time.
    const int height = tile_->rect.height;
    // Runs the shared plane decoding with an already initialized 'state',
    // writing from the start of 'data' (last_pixel_ is 0 here).
    const auto decompress_into_data = [&](auto& state) -> WP2Status {
      return Decompress(global, width, height, gparams_->has_alpha_, state,
                        hdr_.sr_, data);
    };
    switch (global.method) {
      case PlaneCodec::Method::kScpClassical: {
        scp::ClassicalDecState state;
        WP2_CHECK_STATUS(state.Init(width, height));
        WP2_CHECK_STATUS(decompress_into_data(state));
        break;
      }
      case PlaneCodec::Method::kScpJxl: {
        scp::JXLDecState state;
        WP2_CHECK_STATUS(state.Init(width, height));
        // An index equal to kJxlHeaders.size() means the header was sent
        // explicitly instead of being picked among the presets.
        const bool has_custom_header =
            global.jxl_header_index ==
            static_cast<int>(scp::kJxlHeaders.size());
        if (has_custom_header) {
          state.SetHeader(global.jxl_header);
        } else {
          state.SetHeaderIndex(global.jxl_header_index);
        }
        WP2_CHECK_STATUS(decompress_into_data(state));
        break;
      }
      case PlaneCodec::Method::kCalic: {
        scp::CalicDecState state;
        WP2_CHECK_STATUS(state.Init(width, height));
        WP2_CHECK_STATUS(decompress_into_data(state));
        break;
      }
      case PlaneCodec::Method::kNum:
      default:
        assert(false);
        return WP2_STATUS_INVALID_PARAMETER;
    }
    for (size_t y = 0; y < last_row; ++y) {
      WP2_CHECK_STATUS(ProcessRows(y));
    }
  }
  if (src_out != nullptr) *src_out = data + 4 * width * last_row;
  return WP2_STATUS_OK;
}
} // namespace WP2L