Merge "Add unit test for vp9_ext_ratectrl"
diff --git a/test/encode_test_driver.h b/test/encode_test_driver.h
index 3edba4b..38c6195 100644
--- a/test/encode_test_driver.h
+++ b/test/encode_test_driver.h
@@ -148,6 +148,13 @@
     ASSERT_EQ(VPX_CODEC_OK, res) << EncoderError();
   }
 
+#if CONFIG_VP9_ENCODER
+  void Control(int ctrl_id, vpx_rc_funcs_t *arg) {
+    const vpx_codec_err_t res = vpx_codec_control_(&encoder_, ctrl_id, arg);
+    ASSERT_EQ(VPX_CODEC_OK, res) << EncoderError();
+  }
+#endif  // CONFIG_VP9_ENCODER
+
 #if CONFIG_VP8_ENCODER || CONFIG_VP9_ENCODER
   void Control(int ctrl_id, vpx_active_map_t *arg) {
     const vpx_codec_err_t res = vpx_codec_control_(&encoder_, ctrl_id, arg);
diff --git a/test/test.mk b/test/test.mk
index c12fb78..0490238 100644
--- a/test/test.mk
+++ b/test/test.mk
@@ -58,6 +58,7 @@
 LIBVPX_TEST_SRCS-$(CONFIG_VP9_ENCODER) += svc_test.h
 LIBVPX_TEST_SRCS-$(CONFIG_VP9_ENCODER) += svc_end_to_end_test.cc
 LIBVPX_TEST_SRCS-$(CONFIG_VP9_ENCODER) += timestamp_test.cc
+LIBVPX_TEST_SRCS-$(CONFIG_VP9_ENCODER) += vp9_ext_ratectrl_test.cc
 
 LIBVPX_TEST_SRCS-yes                   += decode_test_driver.cc
 LIBVPX_TEST_SRCS-yes                   += decode_test_driver.h
diff --git a/test/vp9_ext_ratectrl_test.cc b/test/vp9_ext_ratectrl_test.cc
new file mode 100644
index 0000000..aa36093
--- /dev/null
+++ b/test/vp9_ext_ratectrl_test.cc
@@ -0,0 +1,194 @@
+/*
+ *  Copyright (c) 2020 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <cstdint>
+#include <memory>
+#include <new>
+
+#include "test/codec_factory.h"
+#include "test/encode_test_driver.h"
+#include "test/util.h"
+#include "test/yuv_video_source.h"
+#include "third_party/googletest/src/include/gtest/gtest.h"
+#include "vpx/vpx_ext_ratectrl.h"
+
+namespace {
+
+// Tag stored in every ToyRateCtrl so each callback can check that it was
+// handed the model allocated by rc_create_model().
+constexpr int kModelMagicNumber = 51396;
+// Opaque value installed as vpx_rc_funcs_t::priv; rc_create_model() checks
+// that the encoder passes it through unchanged.
+constexpr uintptr_t kPrivMagicNumber = 5566;
+// Number of frames encoded from the test clip.
+constexpr int kFrameNum = 5;
+// Coding index for which rc_get_encodeframe_decision() requests q_index 0,
+// so rc_update_encodeframe_result() expects an SSE of 0 for that frame.
+constexpr int kLosslessCodingIndex = 2;
+
+// Input clip size, frame rate and target bitrate; the external rate
+// controller must observe exactly these values.
+constexpr int kFrameWidth = 352;
+constexpr int kFrameHeight = 288;
+constexpr int kFrameRateNum = 30;
+constexpr int kFrameRateDen = 1;
+constexpr int kTargetBitrateKbps = 24000;
+
+// State of the toy external rate-control model.
+struct ToyRateCtrl {
+  int magic_number;  // Always kModelMagicNumber for a live model.
+  int coding_index;  // Last coding index decided; -1 initially.
+};
+
+// create_model callback: allocates a ToyRateCtrl, stores it in
+// *rate_ctrl_model_pt and checks the configuration forwarded by the encoder.
+// Returns 0 on success and -1 if the model cannot be allocated.
+int rc_create_model(void *priv, const vpx_rc_config_t *ratectrl_config,
+                    vpx_rc_model_t *rate_ctrl_model_pt) {
+  ToyRateCtrl *toy_rate_ctrl = new (std::nothrow) ToyRateCtrl;
+  if (toy_rate_ctrl == nullptr) {
+    // Report instead of dereferencing a null pointer after a failed EXPECT.
+    ADD_FAILURE() << "Failed to allocate ToyRateCtrl.";
+    *rate_ctrl_model_pt = nullptr;
+    return -1;
+  }
+  toy_rate_ctrl->magic_number = kModelMagicNumber;
+  toy_rate_ctrl->coding_index = -1;
+  *rate_ctrl_model_pt = static_cast<vpx_rc_model_t>(toy_rate_ctrl);
+  EXPECT_EQ(priv, reinterpret_cast<void *>(kPrivMagicNumber));
+  EXPECT_EQ(ratectrl_config->frame_width, kFrameWidth);
+  EXPECT_EQ(ratectrl_config->frame_height, kFrameHeight);
+  EXPECT_EQ(ratectrl_config->show_frame_count, kFrameNum);
+  EXPECT_EQ(ratectrl_config->target_bitrate_kbps, kTargetBitrateKbps);
+  EXPECT_EQ(ratectrl_config->frame_rate_num, kFrameRateNum);
+  EXPECT_EQ(ratectrl_config->frame_rate_den, kFrameRateDen);
+  return 0;
+}
+
+// send_firstpass_stats callback: expects kFrameNum first-pass entries whose
+// frame field equals their index.
+int rc_send_firstpass_stats(vpx_rc_model_t rate_ctrl_model,
+                            const vpx_rc_firstpass_stats_t *first_pass_stats) {
+  const ToyRateCtrl *toy_rate_ctrl =
+      static_cast<ToyRateCtrl *>(rate_ctrl_model);
+  EXPECT_EQ(toy_rate_ctrl->magic_number, kModelMagicNumber);
+  EXPECT_EQ(first_pass_stats->num_frames, kFrameNum);
+  for (int i = 0; i < first_pass_stats->num_frames; ++i) {
+    EXPECT_DOUBLE_EQ(first_pass_stats->frame_stats[i].frame, i);
+  }
+  return 0;
+}
+
+// get_encodeframe_decision callback: checks the frame type reported for each
+// coding index and picks a q_index (0 at kLosslessCodingIndex, else 100).
+int rc_get_encodeframe_decision(
+    vpx_rc_model_t rate_ctrl_model,
+    const vpx_rc_encodeframe_info_t *encode_frame_info,
+    vpx_rc_encodeframe_decision_t *frame_decision) {
+  ToyRateCtrl *toy_rate_ctrl = static_cast<ToyRateCtrl *>(rate_ctrl_model);
+  toy_rate_ctrl->coding_index += 1;
+
+  EXPECT_EQ(toy_rate_ctrl->magic_number, kModelMagicNumber);
+
+  EXPECT_LT(encode_frame_info->show_index, kFrameNum);
+  EXPECT_EQ(encode_frame_info->coding_index, toy_rate_ctrl->coding_index);
+
+  // Expected coding order: key frame, alt-ref, inter frames, then an overlay
+  // at coding index kFrameNum (assumes one GF group covers all frames).
+  const int coding_index = encode_frame_info->coding_index;
+  if (coding_index == 0) {
+    EXPECT_EQ(encode_frame_info->frame_type, 0 /*kFrameTypeKey*/);
+  } else if (coding_index == 1) {
+    EXPECT_EQ(encode_frame_info->frame_type, 2 /*kFrameTypeAltRef*/);
+  } else if (coding_index >= 2 && coding_index < kFrameNum) {
+    EXPECT_EQ(encode_frame_info->frame_type, 1 /*kFrameTypeInter*/);
+  } else if (coding_index == kFrameNum) {
+    EXPECT_EQ(encode_frame_info->frame_type, 3 /*kFrameTypeOverlay*/);
+  }
+
+  if (coding_index == kLosslessCodingIndex) {
+    // q_index 0 is lossless, so rc_update_encodeframe_result() sees sse == 0.
+    frame_decision->q_index = 0;
+  } else {
+    frame_decision->q_index = 100;
+  }
+  return 0;
+}
+
+// update_encodeframe_result callback: checks the pixel count of the coded
+// frame and that the frame coded with q_index 0 was reconstructed exactly.
+int rc_update_encodeframe_result(
+    vpx_rc_model_t rate_ctrl_model,
+    const vpx_rc_encodeframe_result_t *encode_frame_result) {
+  const ToyRateCtrl *toy_rate_ctrl =
+      static_cast<ToyRateCtrl *>(rate_ctrl_model);
+  EXPECT_EQ(toy_rate_ctrl->magic_number, kModelMagicNumber);
+
+  // Luma plus two quarter-size chroma planes (4:2:0).
+  const int64_t ref_pixel_count =
+      static_cast<int64_t>(kFrameWidth) * kFrameHeight * 3 / 2;
+  EXPECT_EQ(encode_frame_result->pixel_count, ref_pixel_count);
+  if (toy_rate_ctrl->coding_index == kLosslessCodingIndex) {
+    EXPECT_EQ(encode_frame_result->sse, 0);
+  }
+  return 0;
+}
+
+// delete_model callback: frees the model allocated by rc_create_model().
+int rc_delete_model(vpx_rc_model_t rate_ctrl_model) {
+  ToyRateCtrl *toy_rate_ctrl = static_cast<ToyRateCtrl *>(rate_ctrl_model);
+  EXPECT_EQ(toy_rate_ctrl->magic_number, kModelMagicNumber);
+  delete toy_rate_ctrl;
+  return 0;
+}
+
+// Two-pass VP9 encode of kFrameNum frames with the toy external rate
+// controller installed; the checks live in the rc_* callbacks above.
+class ExtRateCtrlTest : public ::libvpx_test::EncoderTest,
+                        public ::testing::Test {
+ protected:
+  ExtRateCtrlTest() : EncoderTest(&::libvpx_test::kVP9) {}
+
+  ~ExtRateCtrlTest() override = default;
+
+  void SetUp() override {
+    InitializeConfig();
+    SetMode(::libvpx_test::kTwoPassGood);
+  }
+
+  // Registers the rc_* callbacks with the encoder before the first frame.
+  void PreEncodeFrameHook(::libvpx_test::VideoSource *video,
+                          ::libvpx_test::Encoder *encoder) override {
+    if (video->frame() == 0) {
+      // Value-initialize so any fields not set below are zero, not garbage.
+      vpx_rc_funcs_t rc_funcs = {};
+      rc_funcs.create_model = rc_create_model;
+      rc_funcs.send_firstpass_stats = rc_send_firstpass_stats;
+      rc_funcs.get_encodeframe_decision = rc_get_encodeframe_decision;
+      rc_funcs.update_encodeframe_result = rc_update_encodeframe_result;
+      rc_funcs.delete_model = rc_delete_model;
+      rc_funcs.priv = reinterpret_cast<void *>(kPrivMagicNumber);
+      encoder->Control(VP9E_SET_EXTERNAL_RATE_CONTROL, &rc_funcs);
+    }
+  }
+};
+
+TEST_F(ExtRateCtrlTest, EncodeTest) {
+  cfg_.rc_target_bitrate = kTargetBitrateKbps;
+
+  std::unique_ptr<libvpx_test::VideoSource> video(
+      new libvpx_test::YUVVideoSource("bus_352x288_420_f20_b8.yuv",
+                                      VPX_IMG_FMT_I420, kFrameWidth,
+                                      kFrameHeight, kFrameRateNum,
+                                      kFrameRateDen, 0, kFrameNum));
+
+  ASSERT_NO_FATAL_FAILURE(RunLoop(video.get()));
+}
+}  // namespace
diff --git a/vp9/encoder/vp9_encoder.c b/vp9/encoder/vp9_encoder.c
index 6ffe41e..8d1d3b8 100644
--- a/vp9/encoder/vp9_encoder.c
+++ b/vp9/encoder/vp9_encoder.c
@@ -2538,8 +2538,6 @@
       num_frames = packets - 1;
       fps_init_first_pass_info(&cpi->twopass.first_pass_info,
                                oxcf->two_pass_stats_in.buf, num_frames);
-      vp9_extrc_send_firstpass_stats(&cpi->ext_ratectrl,
-                                     &cpi->twopass.first_pass_info);
 
       vp9_init_second_pass(cpi);
     }
@@ -5483,9 +5481,9 @@
   {
     const RefCntBuffer *coded_frame_buf =
         get_ref_cnt_buffer(cm, cm->new_fb_idx);
-    vp9_extrc_update_encodeframe_result(&cpi->ext_ratectrl, (*size) << 3,
-                                        cpi->Source, &coded_frame_buf->buf,
-                                        cpi->oxcf.input_bit_depth);
+    vp9_extrc_update_encodeframe_result(
+        &cpi->ext_ratectrl, (*size) << 3, cpi->Source, &coded_frame_buf->buf,
+        cm->bit_depth, cpi->oxcf.input_bit_depth);
   }
 #if CONFIG_REALTIME_ONLY
   (void)encode_frame_result;
@@ -5517,7 +5515,7 @@
         ref_frame_flags,
         cpi->twopass.gf_group.update_type[cpi->twopass.gf_group.index],
         cpi->Source, coded_frame_buf, ref_frame_bufs, vp9_get_quantizer(cpi),
-        cpi->oxcf.input_bit_depth, cm->bit_depth, cpi->td.counts,
+        cm->bit_depth, cpi->oxcf.input_bit_depth, cpi->td.counts,
 #if CONFIG_RATE_CTRL
         cpi->partition_info, cpi->motion_vector_info,
 #endif  // CONFIG_RATE_CTRL
@@ -5674,6 +5672,11 @@
                         unsigned int *frame_flags,
                         ENCODE_FRAME_RESULT *encode_frame_result) {
   cpi->allow_encode_breakout = ENCODE_BREAKOUT_ENABLED;
+
+  if (cpi->common.current_frame_coding_index == 0) {
+    vp9_extrc_send_firstpass_stats(&cpi->ext_ratectrl,
+                                   &cpi->twopass.first_pass_info);
+  }
 #if CONFIG_MISMATCH_DEBUG
   mismatch_move_frame_idx_w();
 #endif
diff --git a/vp9/encoder/vp9_ext_ratectrl.c b/vp9/encoder/vp9_ext_ratectrl.c
index 64414cd..ca75651 100644
--- a/vp9/encoder/vp9_ext_ratectrl.c
+++ b/vp9/encoder/vp9_ext_ratectrl.c
@@ -22,7 +22,7 @@
   ext_ratectrl->ratectrl_config = ratectrl_config;
   ext_ratectrl->funcs.create_model(ext_ratectrl->funcs.priv,
                                    &ext_ratectrl->ratectrl_config,
-                                   ext_ratectrl->model);
+                                   &ext_ratectrl->model);
   rc_firstpass_stats = &ext_ratectrl->rc_firstpass_stats;
   rc_firstpass_stats->num_frames = ratectrl_config.show_frame_count;
   rc_firstpass_stats->frame_stats =
@@ -34,6 +34,7 @@
 void vp9_extrc_delete(EXT_RATECTRL *ext_ratectrl) {
   if (ext_ratectrl->ready) {
     ext_ratectrl->funcs.delete_model(ext_ratectrl->model);
+    vpx_free(ext_ratectrl->rc_firstpass_stats.frame_stats);
   }
   vp9_extrc_init(ext_ratectrl);
 }
@@ -118,6 +119,7 @@
                                          int64_t bit_count,
                                          const YV12_BUFFER_CONFIG *source_frame,
                                          const YV12_BUFFER_CONFIG *coded_frame,
+                                         uint32_t bit_depth,
                                          uint32_t input_bit_depth) {
   if (ext_ratectrl->ready) {
     PSNR_STATS psnr;
@@ -127,9 +129,10 @@
         source_frame->y_width * source_frame->y_height +
         2 * source_frame->uv_width * source_frame->uv_height;
 #if CONFIG_VP9_HIGHBITDEPTH
-    vpx_calc_highbd_psnr(source_frame, coded_frame, &psnr,
-                         source_frame->bit_depth, input_bit_depth);
+    vpx_calc_highbd_psnr(source_frame, coded_frame, &psnr, bit_depth,
+                         input_bit_depth);
 #else
+    (void)bit_depth;
     (void)input_bit_depth;
     vpx_calc_psnr(source_frame, coded_frame, &psnr);
 #endif
diff --git a/vp9/encoder/vp9_ext_ratectrl.h b/vp9/encoder/vp9_ext_ratectrl.h
index 82f300c..fe8a66c 100644
--- a/vp9/encoder/vp9_ext_ratectrl.h
+++ b/vp9/encoder/vp9_ext_ratectrl.h
@@ -40,6 +40,7 @@
                                          int64_t bit_count,
                                          const YV12_BUFFER_CONFIG *source_frame,
                                          const YV12_BUFFER_CONFIG *coded_frame,
+                                         uint32_t bit_depth,
                                          uint32_t input_bit_depth);
 
 #endif  // VPX_VP9_ENCODER_VP9_EXT_RATECTRL_H_
diff --git a/vp9/vp9_cx_iface.c b/vp9/vp9_cx_iface.c
index aa13fc9..52da52a 100644
--- a/vp9/vp9_cx_iface.c
+++ b/vp9/vp9_cx_iface.c
@@ -1739,20 +1739,23 @@
   VP9_COMP *cpi = ctx->cpi;
   EXT_RATECTRL *ext_ratectrl = &cpi->ext_ratectrl;
   const VP9EncoderConfig *oxcf = &cpi->oxcf;
-  const FRAME_INFO *frame_info = &cpi->frame_info;
-  vpx_rc_config_t ratectrl_config;
+  // TODO(angiebird): Check the possibility of this flag being set at pass == 1
+  if (oxcf->pass == 2) {
+    const FRAME_INFO *frame_info = &cpi->frame_info;
+    vpx_rc_config_t ratectrl_config;
 
-  ratectrl_config.frame_width = frame_info->frame_width;
-  ratectrl_config.frame_height = frame_info->frame_height;
-  ratectrl_config.show_frame_count = cpi->twopass.first_pass_info.num_frames;
+    ratectrl_config.frame_width = frame_info->frame_width;
+    ratectrl_config.frame_height = frame_info->frame_height;
+    ratectrl_config.show_frame_count = cpi->twopass.first_pass_info.num_frames;
 
-  // TODO(angiebird): Double check whether this is the proper way to set up
-  // target_bitrate and frame_rate.
-  ratectrl_config.target_bitrate_kbps = (int)(oxcf->target_bandwidth / 1000);
-  ratectrl_config.frame_rate_num = oxcf->g_timebase.den;
-  ratectrl_config.frame_rate_den = oxcf->g_timebase.num;
+    // TODO(angiebird): Double check whether this is the proper way to set up
+    // target_bitrate and frame_rate.
+    ratectrl_config.target_bitrate_kbps = (int)(oxcf->target_bandwidth / 1000);
+    ratectrl_config.frame_rate_num = oxcf->g_timebase.den;
+    ratectrl_config.frame_rate_den = oxcf->g_timebase.num;
 
-  vp9_extrc_create(funcs, ratectrl_config, ext_ratectrl);
+    vp9_extrc_create(funcs, ratectrl_config, ext_ratectrl);
+  }
   return VPX_CODEC_OK;
 }