| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ----------------------------------------------------------------------------- |
| // |
| // Predictors |
| // |
| // Author: Skal (pascal.massimino@gmail.com) |
| // |
| |
| #include "src/common/lossy/predictor.h" |
| |
| #include <algorithm> |
| #include <cmath> |
| #include <cstdio> |
| #include <limits> |
| #include <numeric> |
| |
| #include "src/common/constants.h" |
| #include "src/common/lossy/block.h" |
| #include "src/dec/symbols_dec.h" |
| #include "src/dsp/dsp.h" |
| #include "src/dsp/math.h" |
| #include "src/enc/block_enc.h" |
| #include "src/enc/symbols_enc.h" |
| #include "src/utils/utils.h" |
| #include "src/utils/vector.h" |
| |
| #define LOG 0 |
| |
| namespace WP2 { |
| |
| //------------------------------------------------------------------------------ |
| // PredictorVector |
| |
| void PredictorVector::reset() { |
| for (auto& p : *this) delete p; // we own the predictors |
| Vector<Predictor*>::reset(); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // For debugging only. |
| |
| std::string GetContextAndBlockPixelsStr(const int16_t context[], |
| const int16_t context_right[], |
| const int16_t context_left[], |
| uint32_t w, uint32_t h, |
| const int16_t block[], uint32_t step) { |
| std::string str; |
| str += WP2SPrint(" %4d | ", context[h]); |
| for (uint32_t i = 0; i < w; ++i) str += WP2SPrint("%4d ", context[h + 1 + i]); |
| str += WP2SPrint("| %4d", context[h + 1 + w]); |
| uint32_t num_extra_dashes = 0; |
| const uint32_t tr_size = ContextSize(kContextExtendRight, w, h); |
| uint32_t tr_i = h + 1 + w + 1; |
| if (context_right != nullptr) { |
| str += " "; |
| // Display the first extended right context values. |
| for (uint32_t x = 0; x < 7 && tr_i < tr_size; |
| ++x, ++tr_i, ++num_extra_dashes) { |
| str += WP2SPrint("%4d ", context_right[tr_i]); |
| } |
| } |
| str += "\n----------+-"; |
| for (uint32_t i = 0; i < w; ++i) str += "-----"; |
| str += "+-----"; |
| for (uint32_t i = 0; i < num_extra_dashes; ++i) str += "-----"; |
| str += "\n"; |
| for (uint32_t y = 0; y < h || (context_right != nullptr && tr_i < tr_size); |
| ++y) { |
| if (y < h) { |
| if (context_left != nullptr) { |
| str += WP2SPrint("%4d ", context_left[y]); |
| } else { |
| str += " "; |
| } |
| str += WP2SPrint("%4d | ", context[h - 1 - y]); |
| for (uint32_t x = 0; x < w; ++x) { |
| str += WP2SPrint("%4d ", block[x + y * step]); |
| } |
| str += WP2SPrint("| %4d", context[h + 1 + w + 1 + y]); |
| } else if (y == h) { |
| str += WP2SPrint(" ^___| %*s", 12 + w * 5 + 6 - 8, " "); |
| } else { |
| str += WP2SPrint("%*s", 12 + w * 5 + 6, " "); |
| } |
| // Display remaining extended right context in small batches on the right. |
| if (context_right != nullptr) { |
| if (tr_i < tr_size) { |
| str += " "; |
| for (uint32_t x = 0; x < 4 && tr_i < tr_size; ++x, ++tr_i) { |
| str += WP2SPrint(" %4d", context_right[tr_i]); |
| } |
| } |
| } |
| str += "\n"; |
| } |
| return str; |
| } |
| |
| const int16_t kFakeContext[kContextSize] = { |
| 2, 4, 6, 8, // left |
| 10, // top left |
| 12, 14, 16, 18, // top |
| 20, // top-right |
| 22, 24, 26, 28 // right |
| }; |
| |
| //------------------------------------------------------------------------------ |
| |
| // Dummy predictor that always predicts zero. |
| class ZeroPredictor : public Predictor { |
| public: |
| void Predict(const CodedBlockBase& cb, Channel channel, bool split_tf, |
| uint32_t tf_i, int16_t output[], uint32_t step) const override { |
| const BlockSize split_size = GetSplitSize(cb.dim(), split_tf); |
| const uint32_t split_w = BlockWidthPix(split_size); |
| const uint32_t split_h = BlockHeightPix(split_size); |
| |
| for (uint32_t y = 0; y < split_h; ++y) { |
| std::fill(output + y * step, output + y * step + split_w, 0); |
| } |
| } |
| |
| std::string GetName() const override { return "zero predictor"; } |
| std::string GetFakePredStr() const override { return "zero predictor"; } |
| std::string GetPredStr(const CodedBlockBase& cb, Channel channel, |
| bool split_tf, uint32_t tf_i) const override { |
| return "zero predictor"; |
| } |
| }; |
| |
| //------------------------------------------------------------------------------ |
| |
| void ContextPredictor::Predict(const CodedBlockBase& cb, Channel channel, |
| bool split_tf, uint32_t tf_i, int16_t output[], |
| uint32_t step) const { |
| if (tf_i > 0) assert(split_tf); |
| const int16_t* const context = |
| cb.GetContext(channel, split_tf, tf_i, context_type_); |
| const BlockSize split_size = GetSplitSize(cb.dim(), split_tf); |
| const uint32_t split_w = BlockWidthPix(split_size); |
| const uint32_t split_h = BlockHeightPix(split_size); |
| Predict(context, split_w, split_h, output, step); |
| } |
| |
| std::string ContextPredictor::GetFakePredStr() const { |
| return GetPredStr(kFakeContext, kFakeContext, kFakeContext, kContextSmall, |
| kPredWidth, kPredHeight); |
| } |
| |
| std::string ContextPredictor::GetPredStr(const CodedBlockBase& cb, |
| Channel channel, bool split_tf, |
| uint32_t tf_i) const { |
| const int16_t* const context = |
| cb.GetContext(channel, split_tf, tf_i, kContextSmall); |
| // Always display the top right context. |
| const int16_t* const context_right = |
| cb.GetContext(channel, split_tf, tf_i, kContextExtendRight); |
| const int16_t* const context_left = |
| cb.GetContext(channel, split_tf, tf_i, kContextExtendLeft); |
| const BlockSize split_size = GetSplitSize(cb.dim(), split_tf); |
| |
| return GetPredStr(context, context_right, context_left, context_type_, |
| BlockWidthPix(split_size), BlockHeightPix(split_size)); |
| } |
| |
| std::string ContextPredictor::GetPredStr(const int16_t context[], |
| const int16_t context_right[], |
| const int16_t context_left[], |
| ContextType context_type, |
| uint32_t width, |
| uint32_t height) const { |
| int16_t output[kMaxBlockSizePix2]; |
| Predict((context_type == kContextExtendRight) ? context_right |
| : (context_type == kContextExtendLeft) ? context_left |
| : context, |
| width, height, output, /*step=*/width); |
| return GetContextAndBlockPixelsStr(context, context_right, context_left, |
| width, height, output, /*step=*/width); |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| // The DC predictor sets all pixels to the average of top and/or left context. |
| class DCPredictor : public ContextPredictor { |
| public: |
| enum Type { kAll, kLeft, kTop }; |
| explicit DCPredictor(Type type) : type_(type) {} |
| |
| std::string GetName() const override { |
| const char* const kTypeStr[]{"DC predictor (avg of all context)", |
| "DC predictor (avg of left context)", |
| "DC predictor (avg of top context)"}; |
| return kTypeStr[type_]; |
| } |
| |
| protected: |
| void Predict(const int16_t context[], uint32_t w, uint32_t h, |
| int16_t output[], uint32_t step) const override { |
| // DC predictors will never predict values outside of the input range so we |
| // don't need to pass in accurate min/max bounds. |
| constexpr int16_t kMin = std::numeric_limits<int16_t>::min(); |
| constexpr int16_t kMax = std::numeric_limits<int16_t>::max(); |
| BasePredictors[BPRED_DC + type_](context, w, h, kMin, kMax, output, step); |
| } |
| |
| const Type type_; |
| }; |
| |
| //------------------------------------------------------------------------------ |
| |
| // The MedianDC predictor sets all pixels to the median of top and/or left |
| // context. |
| class MedianDCPredictor : public ContextPredictor { |
| public: |
| enum Type { kAll, kLeft, kTop }; |
| explicit MedianDCPredictor(Type type) : type_(type) {} |
| |
| std::string GetName() const override { |
| const char* const kTypeStr[]{"MedianDC predictor (median of all context)", |
| "MedianDC predictor (median of left context)", |
| "MedianDC predictor (median of top context)"}; |
| return kTypeStr[type_]; |
| } |
| |
| protected: |
| void Predict(const int16_t context[], uint32_t w, uint32_t h, |
| int16_t output[], uint32_t step) const override { |
| // MedianDC predictors will never predict values outside of the input range |
| // so we don't need to pass in accurate min/max bounds. |
| constexpr int16_t kMin = std::numeric_limits<int16_t>::min(); |
| constexpr int16_t kMax = std::numeric_limits<int16_t>::max(); |
| BasePredictors[BPRED_MEDIAN_DC + type_](context, w, h, kMin, kMax, output, |
| step); |
| } |
| |
| const Type type_; |
| }; |
| |
| //------------------------------------------------------------------------------ |
| |
| // TrueMotion predictor. |
| class TMPredictor : public ContextPredictor { |
| public: |
| TMPredictor(int16_t min_value, int16_t max_value) |
| : min_value_(min_value), max_value_(max_value) {} |
| |
| std::string GetName() const override { return "TM predictor"; } |
| |
| protected: |
| void Predict(const int16_t context[], uint32_t w, uint32_t h, |
| int16_t output[], uint32_t step) const override { |
| BasePredictors[BPRED_TM](context, w, h, min_value_, max_value_, output, |
| step); |
| } |
| |
| private: |
| const int16_t min_value_; |
| const int16_t max_value_; |
| }; |
| |
| //------------------------------------------------------------------------------ |
| |
| // Smooth predictor that interpolates pixels horizontally, vertically or a mix |
| // of both. |
| class SmoothPredictor : public ContextPredictor { |
| public: |
| enum class SmoothType { |
| k2DSmooth, |
| kVerticalSmooth, |
| kHorizontalSmooth, |
| kNumSmoothType |
| }; |
| explicit SmoothPredictor(SmoothType type) : type_(type) {} |
| |
| std::string GetName() const override { |
| const char* const kTypeStr[]{"2D", "vertical", "horizontal"}; |
| std::string name = "smooth predictor ("; |
| name += kTypeStr[(int)type_]; |
| name += ")"; |
| return name; |
| } |
| |
| protected: |
| void Predict(const int16_t context[], uint32_t w, uint32_t h, |
| int16_t output[], uint32_t step) const override { |
| // Smooth predictors will never predict values outside of the input range so |
| // we don't need to pass in accurate min/max bounds. |
| constexpr int16_t kMin = std::numeric_limits<int16_t>::min(); |
| constexpr int16_t kMax = std::numeric_limits<int16_t>::max(); |
| const BasePredictor p = (type_ == SmoothType::k2DSmooth) |
| ? BPRED_SMOOTH |
| : (type_ == SmoothType::kVerticalSmooth) |
| ? BPRED_SMOOTH_V |
| : BPRED_SMOOTH_H; |
| BasePredictors[p](context, w, h, kMin, kMax, output, step); |
| } |
| |
| const SmoothType type_; |
| }; |
| |
| //------------------------------------------------------------------------------ |
| |
| static uint32_t GetMaxAngleDelta(Channel channel) { |
| return (channel == kUChannel || channel == kVChannel) |
| ? kDirectionalMaxAngleDeltaUV |
| : kDirectionalMaxAngleDeltaYA; |
| } |
| |
| // Angle predictor. |
| class AnglePredictor : public ContextPredictor { |
| public: |
| // Use AV1 convention to name angle predictors. |
| enum class Type { |
| // Angles must be in the same order as in Predictors::Pred |
| D23_PRED, |
| D45_PRED, |
| D67_PRED, |
| V_PRED, // 90 degrees |
| D113_PRED, |
| D135_PRED, |
| D157_PRED, |
| H_PRED, // 180 degrees |
| D203_PRED, |
| D225_PRED, |
| Num |
| }; |
| AnglePredictor(Type type, Channel channel, uint32_t sub_mode, |
| int16_t angle_delta, int16_t min_value, int16_t max_value) |
| : min_value_(min_value), |
| max_value_(max_value), |
| sub_mode_(sub_mode), |
| angle_idx_(AngleIdx(type, channel, angle_delta)) { |
| context_type_ = GetContextType(angle_idx_); |
| static_assert((uint32_t)Type::Num == kAnglePredNum, |
| "Incorrect kAnglePredNum"); |
| } |
| |
| // Returns the angle index for precalculated angles. |
| // Index goes from 0 (12.86 degrees) to 69 (234.64 degrees). |
| static uint8_t AngleIdx(Type type, Channel channel, int16_t angle_delta) { |
| constexpr int16_t max_delta = kDirectionalMaxAngleDeltaYA; |
| if (channel == kUChannel || channel == kVChannel) { |
| // Note that the U/V prediction angles are not exactly even spaced. |
| // But they map to already existing angles for Y/A. |
| // TODO(skal): should the YA / UV angles be made independent? |
| angle_delta *= 2; |
| } |
| const int idx = (int)type * (2 * max_delta + 1) + max_delta + angle_delta; |
| assert(idx >= 0 && idx < (int)kNumDirectionalAngles); |
| return (uint8_t)idx; |
| } |
| |
| bool IsAngle(float* const angle) const override { |
| if (angle != nullptr) *angle = 22.5f + 22.5f * ((int)angle_idx_ - 3) / 7.f; |
| return true; |
| } |
| |
| uint32_t WriteParams(const CodedBlockBase& cb, Channel channel, |
| SymbolManager* const sm, |
| ANSEncBase* const enc) const override { |
| if (channel == kVChannel) return 0; |
| (void)cb; |
| (void)sm; |
| // By storing the angle step, we are then able, at decoding time, to figure |
| // out the sub-mode of an angle predictor. |
| const uint32_t max_delta = GetMaxAngleDelta(channel); |
| enc->PutRange(sub_mode_, 0, 2 * max_delta, "sub_mode"); |
| return sub_mode_; |
| } |
| |
| uint32_t ReadParams(CodedBlockBase* const cb, Channel channel, |
| SymbolReader* const sm, |
| ANSDec* const dec) const override { |
| if (channel == kVChannel) return 0; |
| (void)sm; |
| // TODO(vrabaud) explore using a symbol for the step, like in AV1. |
| const uint32_t max_delta = GetMaxAngleDelta(channel); |
| return dec->ReadRange(0, 2 * max_delta, "sub_mode"); |
| } |
| |
| std::string GetName() const override { |
| float angle; |
| if (!IsAngle(&angle)) assert(false); |
| const int angle_deg = std::round(angle); |
| return "angle predictor (" + std::to_string(angle_deg) + |
| " degrees, idx=" + std::to_string((int)angle_idx_) + ")"; |
| } |
| |
| protected: |
| void Predict(const int16_t context[], uint32_t w, uint32_t h, |
| int16_t output[], uint32_t step) const override { |
| const uint32_t context_size = ContextSize(context_type_, w, h); |
| for (uint32_t i = 0; i < context_size; ++i) { |
| assert(context[i] >= min_value_ && context[i] <= max_value_); |
| } |
| // SimpleAnglePredictor is faster and currently gives better results than |
| // BaseAnglePredictor. |
| const uint32_t w_idx = TrfLog2[w]; |
| const uint32_t h_idx = TrfLog2[h]; |
| SimpleAnglePredictor(angle_idx_, context, w_idx, h_idx, output, step); |
| } |
| const int16_t min_value_; |
| const int16_t max_value_; |
| |
| private: |
| const uint32_t sub_mode_; // used by encoder during WriteParams() |
| uint32_t angle_idx_; // global angle index, as calculated by AngleIdx() |
| }; |
| |
| //------------------------------------------------------------------------------ |
| |
| // Fuse predictor. |
| class FusePredictor : public ContextPredictor { |
| public: |
| FusePredictor(float strength, int16_t min_value, int16_t max_value) |
| : strength_(strength), min_value_(min_value), max_value_(max_value) { |
| PrecomputeLargeWeightTable(strength_, table_); |
| } |
| |
| std::string GetName() const override { |
| std::string name = "fuse predictor (strength: "; |
| name += std::to_string(strength_); |
| name += ")"; |
| return name; |
| } |
| |
| protected: |
| void Predict(const int16_t context[], uint32_t w, uint32_t h, |
| int16_t output[], uint32_t step) const override { |
| BaseFusePredictor(table_, context, w, h, output, step, min_value_, |
| max_value_); |
| } |
| |
| const float strength_; |
| LargeWeightTable table_; |
| const int16_t min_value_; |
| const int16_t max_value_; |
| }; |
| |
| //------------------------------------------------------------------------------ |
| |
| // Extrapolates a gradient from the bottom-left, top-left and top-right corners. |
| // This is sometimes more powerful than other predictors because it can guess |
| // values outside the ones in the context. |
| class GradientPredictor : public ContextPredictor { |
| public: |
| GradientPredictor(int16_t min_value, int16_t max_value) |
| : min_value_(min_value), max_value_(max_value) {} |
| |
| std::string GetName() const override { return "gradient predictor"; } |
| |
| protected: |
| void Predict(const int16_t c[], uint32_t w, uint32_t h, int16_t output[], |
| uint32_t step) const override { |
| // Average the three known (or extrapolated) corners. |
| const int32_t bottom_left = DivRound(c[0] + c[1] + c[2], 3); |
| const int32_t top_left = |
| DivRound(c[h - 2] + c[h - 1] + c[h] + c[h + 1] + c[h + 2], 5); |
| const int32_t top_right = |
| DivRound(c[h + w - 1] + c[h + w] + c[h + w + 1], 3); |
| // Extrapolate the middle value and from it, the bottom-right corner. |
| const int32_t middle = DivRound(bottom_left + top_right, 2); |
| const int32_t bottom_right = |
| Clamp(middle - top_left + middle, min_value_, max_value_); |
| |
| // Create a gradient by bidimensional interpolation. |
| const int32_t max_x = (int32_t)w - 1, max_y = (int32_t)h - 1; |
| for (int32_t y = 0; y <= max_y; ++y) { |
| const int32_t left = |
| DivRound(top_left * (max_y - y) + bottom_left * y, max_y); |
| const int32_t right = |
| DivRound(top_right * (max_y - y) + bottom_right * y, max_y); |
| for (int32_t x = 0; x <= max_x; ++x) { |
| output[x] = DivRound(left * (max_x - x) + right * x, max_x); |
| assert(output[x] >= min_value_ && output[x] <= max_value_); |
| } |
| output += step; |
| } |
| } |
| |
| const int32_t min_value_; |
| const int32_t max_value_; |
| }; |
| |
| //------------------------------------------------------------------------------ |
| // Alpha predictor. |
| |
| void AdjustAlphaPrediction(const CodedBlockBase& cb, CSPTransform transform, |
| int16_t output[], uint32_t step); |
| |
| // Predictor for the alpha plane, that wraps a normal predictor, and adjusts |
| // its prediction to take into account the fact RGB is premultiplied, therefore |
| // alpha >= max(r, g, b) |
| // This predictor assumes that YUV is encoded/reconstructed before alpha. |
| // The supplied Predictor is not owned, so we don't delete it. |
| class AlphaPredictor : public Predictor { |
| public: |
| // Takes ownership of 'pred'. |
| AlphaPredictor(Predictor* const pred, const CSPTransform& transform) |
| : pred_(pred), transform_(transform) {} |
| |
| ~AlphaPredictor() override { delete pred_; } |
| |
| void Predict(const CodedBlockBase& cb, Channel channel, bool split_tf, |
| uint32_t tf_i, int16_t output[], uint32_t step) const override { |
| assert(!split_tf); |
| assert(channel == kAChannel); // Should only be used for alpha. |
| pred_->Predict(cb, channel, /*split_tf=*/false, /*tf_i=*/0, output, step); |
| AdjustAlphaPrediction(cb, transform_, output, step); |
| } |
| |
| std::string GetName() const override { |
| return "alpha predictor based on " + pred_->GetName(); |
| } |
| std::string GetFakePredStr() const override { |
| printf("alpha predictor based on:\n"); |
| return pred_->GetFakePredStr(); |
| } |
| std::string GetPredStr(const CodedBlockBase& cb, Channel channel, |
| bool split_tf, uint32_t tf_i) const override { |
| assert(!split_tf); |
| std::string str = pred_->GetPredStr(cb, channel, split_tf, tf_i); |
| str += "==> adjusted by alpha predictor:\n"; |
| int16_t output[kMaxBlockSizePix2]; |
| Predict(cb, channel, /*split_tf=*/false, /*tf_i=*/0, output, cb.w_pix()); |
| const int16_t* const context = cb.GetContext(channel, kContextSmall); |
| str += GetContextAndBlockPixelsStr(context, /*context_right=*/nullptr, |
| /*context_left=*/nullptr, cb.w_pix(), |
| cb.h_pix(), output, cb.w_pix()); |
| return str; |
| } |
| |
| uint32_t mode() const override { return pred_->mode(); } |
| bool ComputeParams(CodedBlock* const cb, Channel channel) const override { |
| return pred_->ComputeParams(cb, channel); |
| } |
| uint32_t WriteParams(const CodedBlockBase& cb, Channel channel, |
| SymbolManager* const sm, |
| ANSEncBase* const enc) const override { |
| return pred_->WriteParams(cb, channel, sm, enc); |
| } |
| uint32_t ReadParams( |
| CodedBlockBase* const cb, Channel channel, |
| SymbolReader* const sm, ANSDec* const dec) const override { |
| return pred_->ReadParams(cb, channel, sm, dec); |
| } |
| bool DependsOnLuma() const override { return true; } |
| bool IsAngle(float* const angle) const override { |
| return pred_->IsAngle(angle); |
| } |
| |
| protected: |
| Predictor* const pred_; |
| const CSPTransform& transform_; |
| }; |
| |
| //------------------------------------------------------------------------------ |
| |
| uint32_t Predictors::GetMaxMode() const { |
| return preds_no_angle_.size() + kAnglePredNum - 1; |
| } |
| |
| const Predictor* Predictors::GetPred(uint32_t mode, uint32_t sub_mode) const { |
| const uint32_t index = main_modes_[mode]; |
| assert(index + sub_mode < preds_.size()); |
| assert(mode == GetMaxMode() || (index + sub_mode < main_modes_[mode + 1])); |
| assert(preds_[index]->mode() == preds_[index + sub_mode]->mode()); |
| return preds_[index + sub_mode]; |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| Predictor* Predictors::RecordPredictor( |
| Predictor* pred, uint32_t mode, const CSPTransform* const transform) { |
| if (pred == nullptr) return nullptr; |
| pred->SetMode(mode); |
| if (channel_ == kAChannel) { |
| AlphaPredictor* const alpha_pred = |
| new (WP2Allocable::nothrow) AlphaPredictor(pred, *transform); |
| if (alpha_pred == nullptr) { |
| delete pred; |
| return nullptr; |
| } |
| pred = alpha_pred; |
| } |
| if (!preds_.push_back(pred)) { |
| delete pred; |
| return nullptr; |
| } |
| return pred; |
| } |
| |
| Predictor* MakeNonAnglePredictor(const Predictors::Pred predictor_type, |
| int16_t min_value, int16_t max_value) { |
| switch (predictor_type) { |
| case Predictors::Pred::kDcAll: |
| return new (WP2Allocable::nothrow) DCPredictor(DCPredictor::kAll); |
| case Predictors::Pred::kDcLeft: |
| return new (WP2Allocable::nothrow) DCPredictor(DCPredictor::kLeft); |
| case Predictors::Pred::kDcTop: |
| return new (WP2Allocable::nothrow) DCPredictor(DCPredictor::kTop); |
| case Predictors::Pred::kMedianDcAll: |
| return new (WP2Allocable::nothrow) |
| MedianDCPredictor(MedianDCPredictor::kAll); |
| case Predictors::Pred::kMedianDcLeft: |
| return new (WP2Allocable::nothrow) |
| MedianDCPredictor(MedianDCPredictor::kLeft); |
| case Predictors::Pred::kMedianDcTop: |
| return new (WP2Allocable::nothrow) |
| MedianDCPredictor(MedianDCPredictor::kTop); |
| case Predictors::Pred::kSmooth2D: |
| return new (WP2Allocable::nothrow) |
| SmoothPredictor(SmoothPredictor::SmoothType::k2DSmooth); |
| case Predictors::Pred::kSmoothVertical: |
| return new (WP2Allocable::nothrow) |
| SmoothPredictor(SmoothPredictor::SmoothType::kVerticalSmooth); |
| case Predictors::Pred::kSmoothHorizontal: |
| return new (WP2Allocable::nothrow) |
| SmoothPredictor(SmoothPredictor::SmoothType::kHorizontalSmooth); |
| case Predictors::Pred::kTrueMotion: |
| return new (WP2Allocable::nothrow) TMPredictor(min_value, max_value); |
| case Predictors::Pred::kGradient: |
| return new (WP2Allocable::nothrow) |
| GradientPredictor(min_value, max_value); |
| case Predictors::Pred::kCfl: |
| return new (WP2Allocable::nothrow) CflPredictor(min_value, max_value); |
| case Predictors::Pred::kSignalingCfl: |
| return new (WP2Allocable::nothrow) |
| SignalingCflPredictor(min_value, max_value); |
| case Predictors::Pred::kZero: |
| return new (WP2Allocable::nothrow) ZeroPredictor(); |
| default: |
| // It will fail if you try to make an angle predictor with this function. |
| assert(false); |
| } |
| return nullptr; |
| } |
| |
| namespace { |
| bool IsAnglePredictor(const Predictors::Pred predictor_type) { |
| switch (predictor_type) { |
| case Predictors::Pred::kAngle23: |
| case Predictors::Pred::kAngle45: |
| case Predictors::Pred::kAngle67: |
| case Predictors::Pred::kAngle90: |
| case Predictors::Pred::kAngle113: |
| case Predictors::Pred::kAngle135: |
| case Predictors::Pred::kAngle157: |
| case Predictors::Pred::kAngle180: |
| case Predictors::Pred::kAngle203: |
| case Predictors::Pred::kAngle225: |
| return true; |
| default: |
| return false; |
| } |
| } |
| } // namespace |
| |
| WP2Status Predictors::FillImpl(const Pred* const mapping, uint32_t num_modes, |
| int16_t min_value, int16_t max_value, |
| const CSPTransform* const transform) { |
| const int max_delta = GetMaxAngleDelta(channel_); |
| WP2_CHECK_ALLOC_OK(preds_.reserve( |
| num_modes + (2 * max_delta) * (uint32_t)AnglePredictor::Type::Num)); |
| assert(num_modes >= (uint32_t)AnglePredictor::Type::Num); |
| // DC all should be first (to make it easy to force DC for all blocks, which |
| // is the same as using the first predictor). |
| assert(mapping[0] == Pred::kDcAll); |
| WP2_CHECK_ALLOC_OK( |
| preds_no_angle_.reserve(num_modes - (uint32_t)AnglePredictor::Type::Num)); |
| |
| uint32_t angle_type = 0; |
| uint32_t i = 0; |
| for (; i < num_modes; ++i) { |
| Pred p = mapping[i]; |
| WP2::Predictor* pred = nullptr; |
| |
| main_modes_[i] = preds_.size(); // for fast lookup in GetPred() |
| |
| if (IsAnglePredictor(p)) { |
| for (int32_t sub_mode = 0; sub_mode <= 2 * max_delta; ++sub_mode) { |
| // we map the sub-mode to a delta-angle around the main angle. |
| // It mostly affects the RD-opt search order, since the coding |
| // will be the same in the bitstream (we're using a Range). |
| const int32_t angle_step = |
| (sub_mode & 1) ? -((sub_mode + 1) >> 1) : (sub_mode >> 1); |
| // We assume they're in the same order in ContextPredictor and in |
| // AnglePredictor::Type |
| pred = new (WP2Allocable::nothrow) |
| AnglePredictor((AnglePredictor::Type)angle_type, channel_, sub_mode, |
| angle_step, min_value, max_value); |
| pred = RecordPredictor(pred, i, transform); |
| WP2_CHECK_ALLOC_OK(pred != nullptr); |
| // Store the main angle. |
| if (angle_step == 0) preds_main_angle_[angle_type] = pred; |
| } |
| ++angle_type; |
| } else { |
| pred = MakeNonAnglePredictor(p, min_value, max_value); |
| pred = RecordPredictor(pred, i, transform); |
| WP2_CHECK_ALLOC_OK(pred != nullptr); |
| WP2_CHECK_ALLOC_OK(preds_no_angle_.push_back(pred)); |
| } |
| } |
| assert(i <= GetMaxMode() + 1); |
| for (; i < ArraySize(main_modes_); ++i) { |
| main_modes_[i] = 0; // safety |
| } |
| |
| return WP2_STATUS_OK; |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| WP2Status YPredictors::Fill(int16_t min_value, int16_t max_value) { |
| constexpr Pred kYPreds[] = {Pred::kDcAll, |
| Pred::kDcLeft, |
| Pred::kDcTop, |
| Pred::kZero, |
| Pred::kSmooth2D, |
| Pred::kSmoothVertical, |
| Pred::kSmoothHorizontal, |
| Pred::kTrueMotion, |
| Pred::kGradient, |
| Pred::kAngle23, |
| Pred::kAngle45, |
| Pred::kAngle67, |
| Pred::kAngle90, |
| Pred::kAngle113, |
| Pred::kAngle135, |
| Pred::kAngle157, |
| Pred::kAngle180, |
| Pred::kAngle203, |
| Pred::kAngle225}; |
| |
| STATIC_ASSERT_ARRAY_SIZE(kYPreds, kYPredModeNum); |
| WP2_CHECK_STATUS(FillImpl(kYPreds, kYPredModeNum, min_value, max_value, |
| /*transform=*/nullptr)); |
| return WP2_STATUS_OK; |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| WP2Status APredictors::Fill(int16_t min_value, int16_t max_value, |
| const CSPTransform* const transform) { |
| constexpr Pred kAPreds[] = {Pred::kDcAll, Pred::kSmooth2D, |
| Pred::kSmoothVertical, Pred::kSmoothHorizontal, |
| Pred::kTrueMotion, Pred::kAngle23, |
| Pred::kAngle45, Pred::kAngle67, |
| Pred::kAngle90, Pred::kAngle113, |
| Pred::kAngle135, Pred::kAngle157, |
| Pred::kAngle180, Pred::kAngle203, |
| Pred::kAngle225, Pred::kCfl, |
| Pred::kSignalingCfl, Pred::kZero}; |
| |
| assert(transform != nullptr); |
| STATIC_ASSERT_ARRAY_SIZE(kAPreds, kAPredModeNum); |
| WP2_CHECK_STATUS( |
| FillImpl(kAPreds, kAPredModeNum, min_value, max_value, transform)); |
| return WP2_STATUS_OK; |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| WP2Status UVPredictors::Fill(int16_t min_value, int16_t max_value) { |
| constexpr Pred kUVPreds[] = { |
| Pred::kDcAll, Pred::kCfl, Pred::kSignalingCfl, Pred::kSmooth2D, |
| Pred::kAngle23, Pred::kAngle45, Pred::kAngle67, Pred::kAngle90, |
| Pred::kAngle113, Pred::kAngle135, Pred::kAngle157, Pred::kAngle180, |
| Pred::kAngle203, Pred::kAngle225}; |
| |
| STATIC_ASSERT_ARRAY_SIZE(kUVPreds, kUVPredModeNum); |
| WP2_CHECK_STATUS(FillImpl(kUVPreds, kUVPredModeNum, min_value, max_value, |
| /*transform=*/nullptr)); |
| return WP2_STATUS_OK; |
| } |
| |
| //------------------------------------------------------------------------------ |
| // Chroma-From-Luma |
| |
| CflPredictor::CflPredictor(int16_t min_value, int16_t max_value) |
| : min_value_(min_value), max_value_(max_value) {} |
| |
| void CflPredictor::LinearRegression(const int16_t* const luma, |
| const int16_t* const chroma, uint32_t size, |
| int32_t* const a, int32_t* const b) const { |
| int32_t num_values = 0; |
| int32_t l_sum = 0; |
| int32_t uv_sum = 0; |
| int32_t l_uv_sum = 0; |
| int32_t l_l_sum = 0; |
| for (uint32_t i = 0; i < size; ++i) { |
| if (luma[i] == CodedBlock::kMissing) continue; |
| ++num_values; |
| const int32_t l = luma[i], uv = chroma[i]; |
| l_sum += l; |
| uv_sum += uv; |
| l_uv_sum += l * uv; |
| l_l_sum += l * l; |
| } |
| if (num_values == 0) { |
| *a = *b = 0; |
| return; |
| } |
| const int64_t num = (int64_t)l_uv_sum * num_values - (int64_t)l_sum * uv_sum; |
| const int64_t den = (int64_t)l_l_sum * num_values - (int64_t)l_sum * l_sum; |
| int32_t A = 0; |
| int32_t B = DivRound(LeftShift(uv_sum, kCflFracBits), num_values); |
| if (num != 0 && den != 0) { |
| A = DivRound(LeftShift(num, kCflFracBits), den); // fits in 64b precision |
| B -= DivRound<int64_t>((int64_t)A * l_sum, num_values); |
| } |
| B += (1 << kCflFracBits >> 1); // include rounding constant |
| |
| // Tighten the scaling/offset value to avoid obvious overflows. |
| // Warning! alpha-from-chroma heavily relies on saturating to [0..255], |
| // and generate large 'B' constants. We must not clip 'B' with min/max_value_ |
| // but with the max codable range (kYuvMaxPrec = 9b) to retain this property. |
| const int32_t max_range = 2 << (kCflFracBits + kYuvMaxPrec); |
| if (std::abs(B) > max_range) { |
| A = 0; |
| B = DivRound(LeftShift(uv_sum, kCflFracBits), num_values); |
| } |
| *a = ClampToSigned(A, kCflABits); |
| *b = ClampToSigned(B, kCflBBits); |
| } |
| |
| void CflPredictor::ContextLinearRegression(Channel channel, |
| const CodedBlockBase& cb, |
| int32_t* const a, |
| int32_t* const b) const { |
| const int16_t* const luma_context = |
| cb.GetContext(kYChannel, kContextSmallNoFillIn); |
| const int16_t* const chroma_context = |
| cb.GetContext(channel, kContextSmallNoFillIn); |
| const uint32_t context_size = |
| ContextSize(kContextSmall, cb.w_pix(), cb.h_pix()); |
| LinearRegression(luma_context, chroma_context, context_size, a, b); |
| } |
| |
| void CflPredictor::BlockLinearRegression(Channel channel, const CodedBlock& cb, |
| int32_t* const a, |
| int32_t* const b) const { |
| const Plane16& luma_plane = cb.out_.GetChannel(kYChannel); |
| const Plane16& chroma_plane = cb.in_.GetChannel(channel); |
| int16_t luma[kMaxBlockSizePix2], chroma[kMaxBlockSizePix2]; |
| |
| for (uint32_t k = 0, y = 0; y < cb.h_pix(); ++y) { |
| const int16_t* const y_src = luma_plane.Row(y); |
| const int16_t* const c_src = chroma_plane.Row(y); |
| for (uint32_t x = 0; x < cb.w_pix(); ++x, ++k) { |
| luma[k] = y_src[x]; |
| chroma[k] = c_src[x]; |
| } |
| } |
| LinearRegression(luma, chroma, cb.h_pix() * cb.w_pix(), a, b); |
| } |
| |
| void CflPredictor::DoPredict(const CodedBlockBase& cb, Channel channel, |
| uint32_t x_start, uint32_t y_start, uint32_t w, |
| uint32_t h, int16_t* output, uint32_t step) const { |
| int32_t a, b; |
| ContextLinearRegression(channel, cb, &a, &b); |
| const Plane16& luma = cb.GetContextPlane(kYChannel); |
| for (uint32_t y = 0; y < h; ++y) { |
| const int16_t* const src = &luma.At(x_start, y_start + y); |
| CflPredict(src, (int16_t)a, b, output, min_value_, max_value_, w); |
| output += step; |
| } |
| } |
| |
| void CflPredictor::Predict(const CodedBlockBase& cb, Channel channel, |
| bool split_tf, uint32_t tf_i, int16_t output[], |
| uint32_t step) const { |
| assert(!split_tf); |
| assert(channel != kYChannel); |
| DoPredict(cb, channel, /*x_start=*/0, /*y_start*/ 0, cb.w_pix(), cb.h_pix(), |
| output, step); |
| } |
| |
| std::string CflPredictor::GetName() const { |
| return "chroma-from-luma predictor"; |
| } |
| |
| std::string CflPredictor::GetFakePredStr() const { |
| return "(fake prediction not supported)\n"; |
| } |
| |
| void CflPredictor::DisplaySamples(const CodedBlockBase& cb, Channel channel, |
| std::string* const str) const { |
| #if defined(WP2_BITTRACE) |
| const Plane16& luma = cb.GetContextPlane(kYChannel); |
| int16_t output[kMaxBlockSizePix2]; |
| Predict(cb, channel, /*split_tf=*/false, /*tf_i=*/0, output, |
| kMaxBlockSizePix); |
| |
| // Display subset of values to fit screen. |
| *str += WP2SPrint(" LUMA | chroma\n"); |
| for (uint32_t j = 0; j < kPredHeight; ++j) { |
| for (uint32_t i = 0; i < kPredWidth; ++i) { |
| *str += WP2SPrint("%4d ", luma.At(i, j)); |
| } |
| *str += " | "; |
| for (uint32_t i = 0; i < kPredWidth; ++i) { |
| *str += WP2SPrint("%4d ", output[i + j * kMaxBlockSizePix]); |
| } |
| *str += "\n"; |
| } |
| if (cb.w() > 1 || cb.h() > 1) *str += "(cropped)\n"; |
| #endif |
| } |
| |
| std::string CflPredictor::GetPredStr(const CodedBlockBase& cb, Channel channel, |
| bool split_tf, uint32_t tf_i) const { |
| assert(!split_tf); |
| int32_t a, b; |
| ContextLinearRegression(channel, cb, &a, &b); |
| std::string str = |
| WP2SPrint("Chroma = %.2f * Luma + %.2ff [(%d * Luma + %2d) >> %d]\n", |
| kCflNorm * a, kCflNorm * b, a, b, kCflFracBits); |
| DisplaySamples(cb, channel, &str); |
| return str; |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| SignalingCflPredictor::SignalingCflPredictor(int16_t min_value, |
| int16_t max_value) |
| : CflPredictor(min_value, max_value) {} |
| |
| void SignalingCflPredictor::GetParams(const CodedBlockBase& cb, Channel channel, |
| uint32_t w, uint32_t h, |
| int32_t* a, int32_t* b) const { |
| const Plane16& luma = cb.GetContextPlane(kYChannel); |
| int32_t predicted_a, predicted_b; |
| ContextLinearRegression(channel, cb, &predicted_a, &predicted_b); |
| const int32_t delta_a = |
| ChangePrecision(cb.cfl_[channel - 1], kIOFracBits, kCflFracBits); |
| |
| int32_t luma_sum = 0; |
| for (uint32_t y = 0; y < h; ++y) { |
| const int16_t* const y_src = luma.Row(y); |
| for (uint32_t x = 0; x < w; ++x) luma_sum += y_src[x]; |
| } |
| predicted_b -= DivRound<int64_t>((int64_t)delta_a * luma_sum, w * h); |
| predicted_a += delta_a; |
| *a = ClampToSigned(predicted_a, kCflABits); |
| *b = ClampToSigned(predicted_b, kCflBBits); |
| } |
| |
| void SignalingCflPredictor::DoPredict(const CodedBlockBase& cb, Channel channel, |
| uint32_t, uint32_t, uint32_t w, |
| uint32_t h, int16_t* output, |
| uint32_t step) const { |
| const Plane16& luma = cb.GetContextPlane(kYChannel); |
| int32_t a, b; |
| GetParams(cb, channel, w, h, &a, &b); |
| for (uint32_t y = 0; y < h; ++y) { |
| const int16_t* const y_src = luma.Row(y); |
| CflPredict(y_src, a, b, output, min_value_, max_value_, w); |
| output += step; |
| } |
| } |
| |
| bool SignalingCflPredictor::ComputeParams(CodedBlock* const cb, |
| Channel channel) const { |
| int32_t predicted_a, best_a, predicted_b, best_b; |
| ContextLinearRegression(channel, *cb, &predicted_a, &predicted_b); |
| BlockLinearRegression(channel, *cb, &best_a, &best_b); |
| best_a = ChangePrecision(best_a - predicted_a, kCflFracBits, kIOFracBits); |
| cb->cfl_[channel - 1] = ClampToSigned(best_a, kIOBits); |
| return (cb->cfl_[channel - 1] != 0); |
| } |
| |
| uint32_t SignalingCflPredictor::WriteParams(const CodedBlockBase& cb, |
| Channel channel, |
| SymbolManager* const sm, |
| ANSEncBase* const enc) const { |
| sm->Process(kSymbolCflSlope, cb.cfl_[channel - 1], "cfl_slope", enc); |
| return 0; |
| } |
| |
| uint32_t SignalingCflPredictor::ReadParams(CodedBlockBase* const cb, |
| Channel channel, |
| SymbolReader* const sr, |
| ANSDec* const dec) const { |
| cb->cfl_[channel - 1] = sr->Read(kSymbolCflSlope, "cfl_slope"); |
| return 0; |
| } |
| |
| std::string SignalingCflPredictor::GetName() const { |
| return "signaling chroma-from-luma"; |
| } |
| |
| std::string SignalingCflPredictor::GetPredStr(const CodedBlockBase& cb, |
| Channel channel, bool split_tf, |
| uint32_t tf_i) const { |
| assert(!split_tf); |
| int32_t predicted_a, a, predicted_b, b; |
| ContextLinearRegression(channel, cb, &predicted_a, &predicted_b); |
| GetParams(cb, channel, cb.w(), cb.h_pix(), &a, &b); |
| |
| std::string str; |
| str += WP2SPrint("Predicted a: %.2f, with correction: %.2f (param:%d)\n", |
| kCflNorm * predicted_a, kCflNorm * a, cb.cfl_[channel - 1]); |
| str += WP2SPrint("Predicted b: %.2f, with correction: %.2f\n", |
| kCflNorm * predicted_b, kCflNorm * b); |
| str += WP2SPrint("Chroma = %.2f * Luma + %.2ff [(%d * Luma + %2d) >> %d]\n", |
| kCflNorm * a, kCflNorm * b, a, b, kCflFracBits); |
| DisplaySamples(cb, channel, &str); |
| return str; |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| void AdjustAlphaPrediction(const CodedBlockBase& cb, CSPTransform transform, |
| int16_t output[], uint32_t step) { |
| for (uint16_t j = 0; j < cb.h_pix(); ++j) { |
| const int16_t* const y_src = cb.GetContextPlane(kYChannel).Row(j); |
| const int16_t* const u_src = cb.GetContextPlane(kUChannel).Row(j); |
| const int16_t* const v_src = cb.GetContextPlane(kVChannel).Row( j); |
| for (uint16_t i = 0; i < cb.w_pix(); ++i) { |
| int16_t r, g, b; |
| transform.YuvToRgb8(y_src[i], u_src[i], v_src[i], &r, &g, &b); |
| |
| // Alpha is larger than the largest R, G or B value. |
| // TODO(maryla): since this is based on reconstructed RGB, there might be |
| // some noise. Add a safety margin at lower quality? |
| const int16_t min_alpha_value = std::max({r, g, b, (int16_t)0}); |
| output[j * step + i] = std::max(output[j * step + i], min_alpha_value); |
| } |
| } |
| } |
| |
| // TODO(vrabaud) Remove |
| WP2Status InitAlphaPredictors(const CSPTransform& transform, |
| APredictors* const preds) { |
| const uint32_t kMinValue = 0; |
| const uint32_t kMaxValue = kAlphaMax; |
| |
| preds->reset(); |
| WP2_CHECK_STATUS(preds->Fill(kMinValue, kMaxValue, &transform)); |
| assert(preds->size() == kAPredNum); |
| |
| return WP2_STATUS_OK; |
| } |
| |
| //------------------------------------------------------------------------------ |
| |
| } // namespace WP2 |