Merge "Change to prediction decay calculation."
diff --git a/test/datarate_test.cc b/test/datarate_test.cc
index e66761f..c58adb6 100644
--- a/test/datarate_test.cc
+++ b/test/datarate_test.cc
@@ -1228,7 +1228,7 @@
 
 // Check basic rate targeting for 1 pass CBR SVC with denoising.
 // 2 spatial layers and 3 temporal layers. Run CIF clip with 1 thread.
-TEST_P(DatarateOnePassCbrSvc, OnePassCbrSvc2SpatialLayersDenoiserOn) {
+TEST_P(DatarateOnePassCbrSvc, DISABLED_OnePassCbrSvc2SpatialLayersDenoiserOn) {
   cfg_.rc_buf_initial_sz = 500;
   cfg_.rc_buf_optimal_sz = 500;
   cfg_.rc_buf_sz = 1000;
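
Note (not part of the patch): the DISABLED_ prefix makes gtest build the test but
skip it by default; the test can still be run on demand, for example:

  ./test_libvpx --gtest_also_run_disabled_tests \
      --gtest_filter='*OnePassCbrSvc2SpatialLayersDenoiserOn*'
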
diff --git a/test/partial_idct_test.cc b/test/partial_idct_test.cc
index 1444b6e..3c78cb3 100644
--- a/test/partial_idct_test.cc
+++ b/test/partial_idct_test.cc
@@ -488,6 +488,15 @@
       &highbd_wrapper<vpx_highbd_idct16x16_38_add_neon>, TX_16X16, 38, 12, 2),
   make_tuple(
       &vpx_highbd_fdct16x16_c, &highbd_wrapper<vpx_highbd_idct16x16_256_add_c>,
+      &highbd_wrapper<vpx_highbd_idct16x16_10_add_neon>, TX_16X16, 10, 8, 2),
+  make_tuple(
+      &vpx_highbd_fdct16x16_c, &highbd_wrapper<vpx_highbd_idct16x16_256_add_c>,
+      &highbd_wrapper<vpx_highbd_idct16x16_10_add_neon>, TX_16X16, 10, 10, 2),
+  make_tuple(
+      &vpx_highbd_fdct16x16_c, &highbd_wrapper<vpx_highbd_idct16x16_256_add_c>,
+      &highbd_wrapper<vpx_highbd_idct16x16_10_add_neon>, TX_16X16, 10, 12, 2),
+  make_tuple(
+      &vpx_highbd_fdct16x16_c, &highbd_wrapper<vpx_highbd_idct16x16_256_add_c>,
       &highbd_wrapper<vpx_highbd_idct16x16_1_add_neon>, TX_16X16, 1, 8, 2),
   make_tuple(
       &vpx_highbd_fdct16x16_c, &highbd_wrapper<vpx_highbd_idct16x16_256_add_c>,
diff --git a/test/vp8_fdct4x4_test.cc b/test/vp8_fdct4x4_test.cc
index da4f0ca..4b3282d 100644
--- a/test/vp8_fdct4x4_test.cc
+++ b/test/vp8_fdct4x4_test.cc
@@ -17,12 +17,16 @@
 
 #include "third_party/googletest/src/include/gtest/gtest.h"
 
+#include "./vpx_config.h"
 #include "./vp8_rtcd.h"
 #include "test/acm_random.h"
 #include "vpx/vpx_integer.h"
+#include "vpx_ports/mem.h"
 
 namespace {
 
+typedef void (*FdctFunc)(int16_t *a, int16_t *b, int a_stride);
+
 const int cospi8sqrt2minus1 = 20091;
 const int sinpi8sqrt2 = 35468;
 
@@ -68,8 +72,19 @@
 
 using libvpx_test::ACMRandom;
 
-TEST(VP8FdctTest, SignBiasCheck) {
-  ACMRandom rnd(ACMRandom::DeterministicSeed());
+class FdctTest : public ::testing::TestWithParam<FdctFunc> {
+ public:
+  virtual void SetUp() {
+    fdct_func_ = GetParam();
+    rnd_.Reset(ACMRandom::DeterministicSeed());
+  }
+
+ protected:
+  FdctFunc fdct_func_;
+  ACMRandom rnd_;
+};
+
+TEST_P(FdctTest, SignBiasCheck) {
   int16_t test_input_block[16];
   int16_t test_output_block[16];
   const int pitch = 8;
@@ -81,10 +96,10 @@
   for (int i = 0; i < count_test_block; ++i) {
     // Initialize a test block with input range [-255, 255].
     for (int j = 0; j < 16; ++j) {
-      test_input_block[j] = rnd.Rand8() - rnd.Rand8();
+      test_input_block[j] = rnd_.Rand8() - rnd_.Rand8();
     }
 
-    vp8_short_fdct4x4_c(test_input_block, test_output_block, pitch);
+    fdct_func_(test_input_block, test_output_block, pitch);
 
     for (int j = 0; j < 16; ++j) {
       if (test_output_block[j] < 0) {
@@ -110,10 +125,10 @@
   for (int i = 0; i < count_test_block; ++i) {
     // Initialize a test block with input range [-15, 15].
     for (int j = 0; j < 16; ++j) {
-      test_input_block[j] = (rnd.Rand8() >> 4) - (rnd.Rand8() >> 4);
+      test_input_block[j] = (rnd_.Rand8() >> 4) - (rnd_.Rand8() >> 4);
     }
 
-    vp8_short_fdct4x4_c(test_input_block, test_output_block, pitch);
+    fdct_func_(test_input_block, test_output_block, pitch);
 
     for (int j = 0; j < 16; ++j) {
       if (test_output_block[j] < 0) {
@@ -135,23 +150,22 @@
       << "Error: 4x4 FDCT has a sign bias > 10% for input range [-15, 15]";
 };
 
-TEST(VP8FdctTest, RoundTripErrorCheck) {
-  ACMRandom rnd(ACMRandom::DeterministicSeed());
+TEST_P(FdctTest, RoundTripErrorCheck) {
   int max_error = 0;
   double total_error = 0;
   const int count_test_block = 1000000;
   for (int i = 0; i < count_test_block; ++i) {
-    int16_t test_input_block[16];
+    DECLARE_ALIGNED(16, int16_t, test_input_block[16]);
+    DECLARE_ALIGNED(16, int16_t, test_output_block[16]);
     int16_t test_temp_block[16];
-    int16_t test_output_block[16];
 
     // Initialize a test block with input range [-255, 255].
     for (int j = 0; j < 16; ++j) {
-      test_input_block[j] = rnd.Rand8() - rnd.Rand8();
+      test_input_block[j] = rnd_.Rand8() - rnd_.Rand8();
     }
 
     const int pitch = 8;
-    vp8_short_fdct4x4_c(test_input_block, test_temp_block, pitch);
+    fdct_func_(test_input_block, test_temp_block, pitch);
     reference_idct4x4(test_temp_block, test_output_block);
 
     for (int j = 0; j < 16; ++j) {
@@ -169,4 +183,20 @@
       << "Error: FDCT/IDCT has average roundtrip error > 1 per block";
 };
 
+INSTANTIATE_TEST_CASE_P(C, FdctTest, ::testing::Values(vp8_short_fdct4x4_c));
+
+#if HAVE_NEON
+INSTANTIATE_TEST_CASE_P(NEON, FdctTest,
+                        ::testing::Values(vp8_short_fdct4x4_neon));
+#endif  // HAVE_NEON
+
+#if HAVE_SSE2
+INSTANTIATE_TEST_CASE_P(SSE2, FdctTest,
+                        ::testing::Values(vp8_short_fdct4x4_sse2));
+#endif  // HAVE_SSE2
+
+#if HAVE_MSA
+INSTANTIATE_TEST_CASE_P(MSA, FdctTest,
+                        ::testing::Values(vp8_short_fdct4x4_msa));
+#endif  // HAVE_MSA
 }  // namespace
diff --git a/tools/tiny_ssim.c b/tools/tiny_ssim.c
index eb3253c..f7fee9d 100644
--- a/tools/tiny_ssim.c
+++ b/tools/tiny_ssim.c
@@ -117,7 +117,7 @@
 }
 
 int main(int argc, char *argv[]) {
-  FILE *f[2], *framestats = NULL;
+  FILE *f[2] = { NULL, NULL }, *framestats = NULL;
   uint8_t *buf[2];
   int w, h, tl_skip = 0, tl_skips_remaining = 0;
   double ssimavg = 0, ssimyavg = 0, ssimuavg = 0, ssimvavg = 0;
@@ -126,13 +126,15 @@
   double *ssimy = NULL, *ssimu = NULL, *ssimv = NULL;
   uint64_t *psnry = NULL, *psnru = NULL, *psnrv = NULL;
   size_t i, n_frames = 0, allocated_frames = 0;
+  int return_value = 0;
 
   if (argc < 4 || argc > 6) {
     fprintf(stderr,
             "Usage: %s file1.yuv file2.yuv WxH [tl_skip={0,1,3}] "
             "[framestats.csv]\n",
             argv[0]);
-    return 1;
+    return_value = 1;
+    goto clean_up;
   }
   f[0] = strcmp(argv[1], "-") ? fopen(argv[1], "rb") : stdin;
   f[1] = strcmp(argv[2], "-") ? fopen(argv[2], "rb") : stdin;
@@ -147,17 +149,20 @@
       if (!framestats) {
         fprintf(stderr, "Could not open \"%s\" for writing: %s\n", argv[5],
                 strerror(errno));
-        return 1;
+        return_value = 1;
+        goto clean_up;
       }
     }
   }
   if (!f[0] || !f[1]) {
     fprintf(stderr, "Could not open input files: %s\n", strerror(errno));
-    return 1;
+    return_value = 1;
+    goto clean_up;
   }
   if (w <= 0 || h <= 0 || w & 1 || h & 1) {
     fprintf(stderr, "Invalid size %dx%d\n", w, h);
-    return 1;
+    return_value = 1;
+    goto clean_up;
   }
   buf[0] = malloc(w * h * 3 / 2);
   buf[1] = malloc(w * h * 3 / 2);
@@ -177,7 +182,8 @@
     if (r1 && r2 && r1 != r2) {
       fprintf(stderr, "Failed to read data: %s [%d/%d]\n", strerror(errno),
               (int)r1, (int)r2);
-      return 1;
+      return_value = 1;
+      goto clean_up;
     } else if (r1 == 0 || r2 == 0) {
       break;
     }
@@ -278,8 +284,9 @@
 
   printf("Nframes: %d\n", (int)n_frames);
 
-  if (strcmp(argv[1], "-")) fclose(f[0]);
-  if (strcmp(argv[2], "-")) fclose(f[1]);
+clean_up:
+  if (f[0] && strcmp(argv[1], "-")) fclose(f[0]);
+  if (f[1] && strcmp(argv[2], "-")) fclose(f[1]);
   if (framestats) fclose(framestats);
 
   free(ssimy);
@@ -290,5 +297,5 @@
   free(psnru);
   free(psnrv);
 
-  return 0;
+  return return_value;
 }
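
Note (illustrative sketch, not part of the patch): the new FILE *f[2] = { NULL, NULL }
initializer is what makes the early goto clean_up jumps safe, since the label can now
be reached before any fopen() has run and fclose() must never see an indeterminate
pointer. The single-exit cleanup idiom the tool now follows, in minimal form (error
checks trimmed for brevity):

  #include <stdio.h>
  #include <stdlib.h>

  int process(const char *path) {
    int ret = 0;
    FILE *in = NULL;   /* NULL so clean_up can test it safely */
    char *buf = NULL;  /* free(NULL) is a no-op */

    in = fopen(path, "rb");
    if (!in) {
      ret = 1;
      goto clean_up;   /* nothing opened or allocated yet */
    }
    buf = malloc(1024);
    /* ... read and process ... */

  clean_up:
    if (in) fclose(in);
    free(buf);
    return ret;
  }
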
diff --git a/vp9/encoder/vp9_pickmode.c b/vp9/encoder/vp9_pickmode.c
index 9f2e93a..00a552d 100644
--- a/vp9/encoder/vp9_pickmode.c
+++ b/vp9/encoder/vp9_pickmode.c
@@ -636,17 +636,19 @@
 #if CONFIG_VP9_HIGHBITDEPTH
   // TODO(jingning): Implement the high bit-depth Hadamard transforms and
   // remove this check condition.
-  // TODO(marpan): Disable this for 8 bit once optimizations for the functions
-  // below are merged in.
-  // if (xd->bd != 8) {
-  unsigned int var_y, sse_y;
-  (void)tx_size;
-  model_rd_for_sb_y(cpi, bsize, x, xd, &this_rdc->rate, &this_rdc->dist, &var_y,
-                    &sse_y);
-  *sse = INT_MAX;
-  *skippable = 0;
-  return;
-// }
+  // TODO(marpan): For now, use this path (model_rd) for 8 bit under certain
+  // conditions, as the vp9_quantize_fp below is slow for high-bitdepth builds.
+  if (xd->bd != 8 ||
+      (cpi->oxcf.speed > 5 && cpi->common.frame_type != KEY_FRAME &&
+       bsize < BLOCK_32X32)) {
+    unsigned int var_y, sse_y;
+    (void)tx_size;
+    model_rd_for_sb_y(cpi, bsize, x, xd, &this_rdc->rate, &this_rdc->dist,
+                      &var_y, &sse_y);
+    *sse = INT_MAX;
+    *skippable = 0;
+    return;
+  }
 #endif
 
   (void)cpi;
diff --git a/vpx_dsp/arm/highbd_idct16x16_add_neon.c b/vpx_dsp/arm/highbd_idct16x16_add_neon.c
index 1a97bfd..d361c82 100644
--- a/vpx_dsp/arm/highbd_idct16x16_add_neon.c
+++ b/vpx_dsp/arm/highbd_idct16x16_add_neon.c
@@ -33,6 +33,19 @@
   d1->val[1] = vcombine_s32(t32[3].val[0], t32[3].val[1]);
 }
 
+static INLINE void highbd_idct16x16_add_wrap_low_4x2(const int64x2x2_t *const t,
+                                                     int32x4_t *const d0,
+                                                     int32x4_t *const d1) {
+  int32x2x2_t t32[2];
+
+  t32[0].val[0] = vrshrn_n_s64(t[0].val[0], DCT_CONST_BITS);
+  t32[0].val[1] = vrshrn_n_s64(t[0].val[1], DCT_CONST_BITS);
+  t32[1].val[0] = vrshrn_n_s64(t[1].val[0], DCT_CONST_BITS);
+  t32[1].val[1] = vrshrn_n_s64(t[1].val[1], DCT_CONST_BITS);
+  *d0 = vcombine_s32(t32[0].val[0], t32[0].val[1]);
+  *d1 = vcombine_s32(t32[1].val[0], t32[1].val[1]);
+}
+
 static INLINE int32x4x2_t
 highbd_idct16x16_add_wrap_low_8x1(const int64x2x2_t *const t) {
   int32x2x2_t t32[2];
@@ -47,6 +60,14 @@
   return d;
 }
 
+static INLINE int32x4_t highbd_idct16x16_add_wrap_low_4x1(const int64x2x2_t t) {
+  int32x2x2_t t32;
+
+  t32.val[0] = vrshrn_n_s64(t.val[0], DCT_CONST_BITS);
+  t32.val[1] = vrshrn_n_s64(t.val[1], DCT_CONST_BITS);
+  return vcombine_s32(t32.val[0], t32.val[1]);
+}
+
 static INLINE void highbd_idct_cospi_2_30(const int32x4x2_t s0,
                                           const int32x4x2_t s1,
                                           const int32x4_t cospi_2_30_10_22,
@@ -336,6 +357,27 @@
                                vget_low_s32(cospi_0_8_16_24), 1);
 }
 
+static INLINE void highbd_idct_cospi_8_24_d_kernel(
+    const int32x4_t s0, const int32x4_t s1, const int32x4_t cospi_0_8_16_24,
+    int64x2x2_t *const t) {
+  t[0].val[0] =
+      vmull_lane_s32(vget_low_s32(s0), vget_high_s32(cospi_0_8_16_24), 1);
+  t[0].val[1] =
+      vmull_lane_s32(vget_high_s32(s0), vget_high_s32(cospi_0_8_16_24), 1);
+  t[1].val[0] =
+      vmull_lane_s32(vget_low_s32(s1), vget_high_s32(cospi_0_8_16_24), 1);
+  t[1].val[1] =
+      vmull_lane_s32(vget_high_s32(s1), vget_high_s32(cospi_0_8_16_24), 1);
+  t[0].val[0] = vmlsl_lane_s32(t[0].val[0], vget_low_s32(s1),
+                               vget_low_s32(cospi_0_8_16_24), 1);
+  t[0].val[1] = vmlsl_lane_s32(t[0].val[1], vget_high_s32(s1),
+                               vget_low_s32(cospi_0_8_16_24), 1);
+  t[1].val[0] = vmlal_lane_s32(t[1].val[0], vget_low_s32(s0),
+                               vget_low_s32(cospi_0_8_16_24), 1);
+  t[1].val[1] = vmlal_lane_s32(t[1].val[1], vget_high_s32(s0),
+                               vget_low_s32(cospi_0_8_16_24), 1);
+}
+
 static INLINE void highbd_idct_cospi_8_24_q(const int32x4x2_t s0,
                                             const int32x4x2_t s1,
                                             const int32x4_t cospi_0_8_16_24,
@@ -347,6 +389,17 @@
   highbd_idct16x16_add_wrap_low_8x2(t, d0, d1);
 }
 
+static INLINE void highbd_idct_cospi_8_24_d(const int32x4_t s0,
+                                            const int32x4_t s1,
+                                            const int32x4_t cospi_0_8_16_24,
+                                            int32x4_t *const d0,
+                                            int32x4_t *const d1) {
+  int64x2x2_t t[2];
+
+  highbd_idct_cospi_8_24_d_kernel(s0, s1, cospi_0_8_16_24, t);
+  highbd_idct16x16_add_wrap_low_4x2(t, d0, d1);
+}
+
 static INLINE void highbd_idct_cospi_8_24_neg_q(const int32x4x2_t s0,
                                                 const int32x4x2_t s1,
                                                 const int32x4_t cospi_0_8_16_24,
@@ -362,6 +415,19 @@
   highbd_idct16x16_add_wrap_low_8x2(t, d0, d1);
 }
 
+static INLINE void highbd_idct_cospi_8_24_neg_d(const int32x4_t s0,
+                                                const int32x4_t s1,
+                                                const int32x4_t cospi_0_8_16_24,
+                                                int32x4_t *const d0,
+                                                int32x4_t *const d1) {
+  int64x2x2_t t[2];
+
+  highbd_idct_cospi_8_24_d_kernel(s0, s1, cospi_0_8_16_24, t);
+  t[1].val[0] = vsubq_s64(vdupq_n_s64(0), t[1].val[0]);
+  t[1].val[1] = vsubq_s64(vdupq_n_s64(0), t[1].val[1]);
+  highbd_idct16x16_add_wrap_low_4x2(t, d0, d1);
+}
+
 static INLINE void highbd_idct_cospi_16_16_q(const int32x4x2_t s0,
                                              const int32x4x2_t s1,
                                              const int32x4_t cospi_0_8_16_24,
@@ -396,8 +462,30 @@
   highbd_idct16x16_add_wrap_low_8x2(t, d0, d1);
 }
 
-static INLINE void highbd_idct16x16_add_stage7(const int32x4x2_t *const step2,
-                                               int32x4x2_t *const out) {
+static INLINE void highbd_idct_cospi_16_16_d(const int32x4_t s0,
+                                             const int32x4_t s1,
+                                             const int32x4_t cospi_0_8_16_24,
+                                             int32x4_t *const d0,
+                                             int32x4_t *const d1) {
+  int64x2x2_t t[3];
+
+  t[2].val[0] =
+      vmull_lane_s32(vget_low_s32(s1), vget_high_s32(cospi_0_8_16_24), 0);
+  t[2].val[1] =
+      vmull_lane_s32(vget_high_s32(s1), vget_high_s32(cospi_0_8_16_24), 0);
+  t[0].val[0] = vmlsl_lane_s32(t[2].val[0], vget_low_s32(s0),
+                               vget_high_s32(cospi_0_8_16_24), 0);
+  t[0].val[1] = vmlsl_lane_s32(t[2].val[1], vget_high_s32(s0),
+                               vget_high_s32(cospi_0_8_16_24), 0);
+  t[1].val[0] = vmlal_lane_s32(t[2].val[0], vget_low_s32(s0),
+                               vget_high_s32(cospi_0_8_16_24), 0);
+  t[1].val[1] = vmlal_lane_s32(t[2].val[1], vget_high_s32(s0),
+                               vget_high_s32(cospi_0_8_16_24), 0);
+  highbd_idct16x16_add_wrap_low_4x2(t, d0, d1);
+}
+
+static INLINE void highbd_idct16x16_add_stage7_dual(
+    const int32x4x2_t *const step2, int32x4x2_t *const out) {
   out[0].val[0] = vaddq_s32(step2[0].val[0], step2[15].val[0]);
   out[0].val[1] = vaddq_s32(step2[0].val[1], step2[15].val[1]);
   out[1].val[0] = vaddq_s32(step2[1].val[0], step2[14].val[0]);
@@ -432,6 +520,26 @@
   out[15].val[1] = vsubq_s32(step2[0].val[1], step2[15].val[1]);
 }
 
+static INLINE void highbd_idct16x16_add_stage7(const int32x4_t *const step2,
+                                               int32x4_t *const out) {
+  out[0] = vaddq_s32(step2[0], step2[15]);
+  out[1] = vaddq_s32(step2[1], step2[14]);
+  out[2] = vaddq_s32(step2[2], step2[13]);
+  out[3] = vaddq_s32(step2[3], step2[12]);
+  out[4] = vaddq_s32(step2[4], step2[11]);
+  out[5] = vaddq_s32(step2[5], step2[10]);
+  out[6] = vaddq_s32(step2[6], step2[9]);
+  out[7] = vaddq_s32(step2[7], step2[8]);
+  out[8] = vsubq_s32(step2[7], step2[8]);
+  out[9] = vsubq_s32(step2[6], step2[9]);
+  out[10] = vsubq_s32(step2[5], step2[10]);
+  out[11] = vsubq_s32(step2[4], step2[11]);
+  out[12] = vsubq_s32(step2[3], step2[12]);
+  out[13] = vsubq_s32(step2[2], step2[13]);
+  out[14] = vsubq_s32(step2[1], step2[14]);
+  out[15] = vsubq_s32(step2[0], step2[15]);
+}
+
 static INLINE void highbd_idct16x16_store_pass1(const int32x4x2_t *const out,
                                                 int32_t *output) {
   // Save the result into output
@@ -745,7 +853,7 @@
   step2[15] = step1[15];
 
   // stage 7
-  highbd_idct16x16_add_stage7(step2, out);
+  highbd_idct16x16_add_stage7_dual(step2, out);
 
   if (output) {
     highbd_idct16x16_store_pass1(out, output);
@@ -765,6 +873,15 @@
   return highbd_idct16x16_add_wrap_low_8x1(t);
 }
 
+static INLINE int32x4_t highbd_idct_cospi_lane0(const int32x4_t s,
+                                                const int32x2_t coef) {
+  int64x2x2_t t;
+
+  t.val[0] = vmull_lane_s32(vget_low_s32(s), coef, 0);
+  t.val[1] = vmull_lane_s32(vget_high_s32(s), coef, 0);
+  return highbd_idct16x16_add_wrap_low_4x1(t);
+}
+
 static INLINE int32x4x2_t highbd_idct_cospi_lane1_dual(const int32x4x2_t s,
                                                        const int32x2_t coef) {
   int64x2x2_t t[2];
@@ -776,6 +893,15 @@
   return highbd_idct16x16_add_wrap_low_8x1(t);
 }
 
+static INLINE int32x4_t highbd_idct_cospi_lane1(const int32x4_t s,
+                                                const int32x2_t coef) {
+  int64x2x2_t t;
+
+  t.val[0] = vmull_lane_s32(vget_low_s32(s), coef, 1);
+  t.val[1] = vmull_lane_s32(vget_high_s32(s), coef, 1);
+  return highbd_idct16x16_add_wrap_low_4x1(t);
+}
+
 static INLINE int32x4x2_t highbd_idct_add_dual(const int32x4x2_t s0,
                                                const int32x4x2_t s1) {
   int32x4x2_t t;
@@ -939,8 +1065,273 @@
   step2[15] = step1[15];
 
   // stage 7
+  highbd_idct16x16_add_stage7_dual(step2, out);
+
+  if (output) {
+    highbd_idct16x16_store_pass1(out, output);
+  } else {
+    highbd_idct16x16_add_store(out, dest, stride, bd);
+  }
+}
+
+void highbd_idct16x16_10_add_half1d_pass1(const tran_low_t *input,
+                                          int32_t *output) {
+  const int32x4_t cospi_0_8_16_24 = vld1q_s32(kCospi32 + 0);
+  const int32x4_t cospi_4_12_20N_28 = vld1q_s32(kCospi32 + 4);
+  const int32x4_t cospi_2_30_10_22 = vld1q_s32(kCospi32 + 8);
+  const int32x4_t cospi_6_26_14_18N = vld1q_s32(kCospi32 + 12);
+  int32x4_t in[4], step1[16], step2[16], out[16];
+
+  // Load input (4x4)
+  in[0] = vld1q_s32(input);
+  input += 16;
+  in[1] = vld1q_s32(input);
+  input += 16;
+  in[2] = vld1q_s32(input);
+  input += 16;
+  in[3] = vld1q_s32(input);
+
+  // Transpose
+  transpose_s32_4x4(&in[0], &in[1], &in[2], &in[3]);
+
+  // stage 1
+  step1[0] = in[0 / 2];
+  step1[4] = in[4 / 2];
+  step1[8] = in[2 / 2];
+  step1[12] = in[6 / 2];
+
+  // stage 2
+  step2[0] = step1[0];
+  step2[4] = step1[4];
+  step2[8] = highbd_idct_cospi_lane1(step1[8], vget_low_s32(cospi_2_30_10_22));
+  step2[11] =
+      highbd_idct_cospi_lane1(step1[12], vget_low_s32(cospi_6_26_14_18N));
+  step2[12] =
+      highbd_idct_cospi_lane0(step1[12], vget_low_s32(cospi_6_26_14_18N));
+  step2[15] = highbd_idct_cospi_lane0(step1[8], vget_low_s32(cospi_2_30_10_22));
+
+  // stage 3
+  step1[0] = step2[0];
+  step1[4] =
+      highbd_idct_cospi_lane1(step2[4], vget_high_s32(cospi_4_12_20N_28));
+  step1[7] = highbd_idct_cospi_lane0(step2[4], vget_low_s32(cospi_4_12_20N_28));
+  step1[8] = step2[8];
+  step1[9] = step2[8];
+  step1[10] = step2[11];
+  step1[11] = step2[11];
+  step1[12] = step2[12];
+  step1[13] = step2[12];
+  step1[14] = step2[15];
+  step1[15] = step2[15];
+
+  // stage 4
+  step2[0] = step2[1] =
+      highbd_idct_cospi_lane0(step1[0], vget_high_s32(cospi_0_8_16_24));
+  step2[4] = step1[4];
+  step2[5] = step1[4];
+  step2[6] = step1[7];
+  step2[7] = step1[7];
+  step2[8] = step1[8];
+  highbd_idct_cospi_8_24_d(step1[14], step1[9], cospi_0_8_16_24, &step2[9],
+                           &step2[14]);
+  highbd_idct_cospi_8_24_neg_d(step1[13], step1[10], cospi_0_8_16_24,
+                               &step2[13], &step2[10]);
+  step2[11] = step1[11];
+  step2[12] = step1[12];
+  step2[15] = step1[15];
+
+  // stage 5
+  step1[0] = step2[0];
+  step1[1] = step2[1];
+  step1[2] = step2[1];
+  step1[3] = step2[0];
+  step1[4] = step2[4];
+  highbd_idct_cospi_16_16_d(step2[5], step2[6], cospi_0_8_16_24, &step1[5],
+                            &step1[6]);
+  step1[7] = step2[7];
+  step1[8] = vaddq_s32(step2[8], step2[11]);
+  step1[9] = vaddq_s32(step2[9], step2[10]);
+  step1[10] = vsubq_s32(step2[9], step2[10]);
+  step1[11] = vsubq_s32(step2[8], step2[11]);
+  step1[12] = vsubq_s32(step2[15], step2[12]);
+  step1[13] = vsubq_s32(step2[14], step2[13]);
+  step1[14] = vaddq_s32(step2[14], step2[13]);
+  step1[15] = vaddq_s32(step2[15], step2[12]);
+
+  // stage 6
+  step2[0] = vaddq_s32(step1[0], step1[7]);
+  step2[1] = vaddq_s32(step1[1], step1[6]);
+  step2[2] = vaddq_s32(step1[2], step1[5]);
+  step2[3] = vaddq_s32(step1[3], step1[4]);
+  step2[4] = vsubq_s32(step1[3], step1[4]);
+  step2[5] = vsubq_s32(step1[2], step1[5]);
+  step2[6] = vsubq_s32(step1[1], step1[6]);
+  step2[7] = vsubq_s32(step1[0], step1[7]);
+  highbd_idct_cospi_16_16_d(step1[10], step1[13], cospi_0_8_16_24, &step2[10],
+                            &step2[13]);
+  highbd_idct_cospi_16_16_d(step1[11], step1[12], cospi_0_8_16_24, &step2[11],
+                            &step2[12]);
+  step2[8] = step1[8];
+  step2[9] = step1[9];
+  step2[14] = step1[14];
+  step2[15] = step1[15];
+
+  // stage 7
   highbd_idct16x16_add_stage7(step2, out);
 
+  // pass 1: save the result into output
+  vst1q_s32(output, out[0]);
+  output += 4;
+  vst1q_s32(output, out[1]);
+  output += 4;
+  vst1q_s32(output, out[2]);
+  output += 4;
+  vst1q_s32(output, out[3]);
+  output += 4;
+  vst1q_s32(output, out[4]);
+  output += 4;
+  vst1q_s32(output, out[5]);
+  output += 4;
+  vst1q_s32(output, out[6]);
+  output += 4;
+  vst1q_s32(output, out[7]);
+  output += 4;
+  vst1q_s32(output, out[8]);
+  output += 4;
+  vst1q_s32(output, out[9]);
+  output += 4;
+  vst1q_s32(output, out[10]);
+  output += 4;
+  vst1q_s32(output, out[11]);
+  output += 4;
+  vst1q_s32(output, out[12]);
+  output += 4;
+  vst1q_s32(output, out[13]);
+  output += 4;
+  vst1q_s32(output, out[14]);
+  output += 4;
+  vst1q_s32(output, out[15]);
+}
+
+void highbd_idct16x16_10_add_half1d_pass2(const int32_t *input,
+                                          int32_t *const output,
+                                          uint16_t *const dest,
+                                          const int stride, const int bd) {
+  const int32x4_t cospi_0_8_16_24 = vld1q_s32(kCospi32 + 0);
+  const int32x4_t cospi_4_12_20N_28 = vld1q_s32(kCospi32 + 4);
+  const int32x4_t cospi_2_30_10_22 = vld1q_s32(kCospi32 + 8);
+  const int32x4_t cospi_6_26_14_18N = vld1q_s32(kCospi32 + 12);
+  int32x4x2_t in[4], step1[16], step2[16], out[16];
+
+  // Load input (4x8)
+  in[0].val[0] = vld1q_s32(input);
+  input += 4;
+  in[0].val[1] = vld1q_s32(input);
+  input += 4;
+  in[1].val[0] = vld1q_s32(input);
+  input += 4;
+  in[1].val[1] = vld1q_s32(input);
+  input += 4;
+  in[2].val[0] = vld1q_s32(input);
+  input += 4;
+  in[2].val[1] = vld1q_s32(input);
+  input += 4;
+  in[3].val[0] = vld1q_s32(input);
+  input += 4;
+  in[3].val[1] = vld1q_s32(input);
+
+  // Transpose
+  transpose_s32_4x8(&in[0].val[0], &in[0].val[1], &in[1].val[0], &in[1].val[1],
+                    &in[2].val[0], &in[2].val[1], &in[3].val[0], &in[3].val[1]);
+
+  // stage 1
+  step1[0] = in[0 / 2];
+  step1[4] = in[4 / 2];
+  step1[8] = in[2 / 2];
+  step1[12] = in[6 / 2];
+
+  // stage 2
+  step2[0] = step1[0];
+  step2[4] = step1[4];
+  step2[8] =
+      highbd_idct_cospi_lane1_dual(step1[8], vget_low_s32(cospi_2_30_10_22));
+  step2[11] =
+      highbd_idct_cospi_lane1_dual(step1[12], vget_low_s32(cospi_6_26_14_18N));
+  step2[12] =
+      highbd_idct_cospi_lane0_dual(step1[12], vget_low_s32(cospi_6_26_14_18N));
+  step2[15] =
+      highbd_idct_cospi_lane0_dual(step1[8], vget_low_s32(cospi_2_30_10_22));
+
+  // stage 3
+  step1[0] = step2[0];
+  step1[4] =
+      highbd_idct_cospi_lane1_dual(step2[4], vget_high_s32(cospi_4_12_20N_28));
+  step1[7] =
+      highbd_idct_cospi_lane0_dual(step2[4], vget_low_s32(cospi_4_12_20N_28));
+  step1[8] = step2[8];
+  step1[9] = step2[8];
+  step1[10] = step2[11];
+  step1[11] = step2[11];
+  step1[12] = step2[12];
+  step1[13] = step2[12];
+  step1[14] = step2[15];
+  step1[15] = step2[15];
+
+  // stage 4
+  step2[0] = step2[1] =
+      highbd_idct_cospi_lane0_dual(step1[0], vget_high_s32(cospi_0_8_16_24));
+  step2[4] = step1[4];
+  step2[5] = step1[4];
+  step2[6] = step1[7];
+  step2[7] = step1[7];
+  step2[8] = step1[8];
+  highbd_idct_cospi_8_24_q(step1[14], step1[9], cospi_0_8_16_24, &step2[9],
+                           &step2[14]);
+  highbd_idct_cospi_8_24_neg_q(step1[13], step1[10], cospi_0_8_16_24,
+                               &step2[13], &step2[10]);
+  step2[11] = step1[11];
+  step2[12] = step1[12];
+  step2[15] = step1[15];
+
+  // stage 5
+  step1[0] = step2[0];
+  step1[1] = step2[1];
+  step1[2] = step2[1];
+  step1[3] = step2[0];
+  step1[4] = step2[4];
+  highbd_idct_cospi_16_16_q(step2[5], step2[6], cospi_0_8_16_24, &step1[5],
+                            &step1[6]);
+  step1[7] = step2[7];
+  step1[8] = highbd_idct_add_dual(step2[8], step2[11]);
+  step1[9] = highbd_idct_add_dual(step2[9], step2[10]);
+  step1[10] = highbd_idct_sub_dual(step2[9], step2[10]);
+  step1[11] = highbd_idct_sub_dual(step2[8], step2[11]);
+  step1[12] = highbd_idct_sub_dual(step2[15], step2[12]);
+  step1[13] = highbd_idct_sub_dual(step2[14], step2[13]);
+  step1[14] = highbd_idct_add_dual(step2[14], step2[13]);
+  step1[15] = highbd_idct_add_dual(step2[15], step2[12]);
+
+  // stage 6
+  step2[0] = highbd_idct_add_dual(step1[0], step1[7]);
+  step2[1] = highbd_idct_add_dual(step1[1], step1[6]);
+  step2[2] = highbd_idct_add_dual(step1[2], step1[5]);
+  step2[3] = highbd_idct_add_dual(step1[3], step1[4]);
+  step2[4] = highbd_idct_sub_dual(step1[3], step1[4]);
+  step2[5] = highbd_idct_sub_dual(step1[2], step1[5]);
+  step2[6] = highbd_idct_sub_dual(step1[1], step1[6]);
+  step2[7] = highbd_idct_sub_dual(step1[0], step1[7]);
+  highbd_idct_cospi_16_16_q(step1[10], step1[13], cospi_0_8_16_24, &step2[10],
+                            &step2[13]);
+  highbd_idct_cospi_16_16_q(step1[11], step1[12], cospi_0_8_16_24, &step2[11],
+                            &step2[12]);
+  step2[8] = step1[8];
+  step2[9] = step1[9];
+  step2[14] = step1[14];
+  step2[15] = step1[15];
+
+  // stage 7
+  highbd_idct16x16_add_stage7_dual(step2, out);
+
   if (output) {
     highbd_idct16x16_store_pass1(out, output);
   } else {
@@ -1026,6 +1417,42 @@
   }
 }
 
+void vpx_highbd_idct16x16_10_add_neon(const tran_low_t *input, uint8_t *dest8,
+                                      int stride, int bd) {
+  uint16_t *dest = CONVERT_TO_SHORTPTR(dest8);
+
+  if (bd == 8) {
+    int16_t row_idct_output[4 * 16];
+
+    // pass 1
+    // Parallel idct on the upper 8 rows
+    idct16x16_10_add_half1d_pass1(input, row_idct_output);
+
+    // pass 2
+    // Parallel idct to get the left 8 columns
+    idct16x16_10_add_half1d_pass2(row_idct_output, NULL, dest, stride, 1);
+
+    // Parallel idct to get the right 8 columns
+    idct16x16_10_add_half1d_pass2(row_idct_output + 4 * 8, NULL, dest + 8,
+                                  stride, 1);
+  } else {
+    int32_t row_idct_output[4 * 16];
+
+    // pass 1
+    // Parallel idct on the upper 8 rows
+    highbd_idct16x16_10_add_half1d_pass1(input, row_idct_output);
+
+    // pass 2
+    // Parallel idct to get the left 8 columns
+    highbd_idct16x16_10_add_half1d_pass2(row_idct_output, NULL, dest, stride,
+                                         bd);
+
+    // Parallel idct to get the right 8 columns
+    highbd_idct16x16_10_add_half1d_pass2(row_idct_output + 4 * 8, NULL,
+                                         dest + 8, stride, bd);
+  }
+}
+
 static INLINE void highbd_idct16x16_1_add_pos_kernel(uint16_t **dest,
                                                      const int stride,
                                                      const int16x8_t res,
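
For context (not part of the patch, reconstructed from memory of vp9/common/vp9_idct.c,
so treat the thresholds as approximate): the _10 suffix means the variant only has to
handle blocks whose last non-zero coefficient (eob) is at most 10, all of which fall in
the top-left 4x4 of the 16x16 block; that is why pass 1 above loads just a 4x4 tile of
the input. The caller selects a variant roughly as in this hypothetical sketch:

  /* Hypothetical dispatch sketch; check vp9/common/vp9_idct.c for the real
   * thresholds and signature. */
  static void highbd_idct16x16_add_dispatch(const tran_low_t *input,
                                            uint8_t *dest, int stride,
                                            int eob, int bd) {
    if (eob == 1)
      vpx_highbd_idct16x16_1_add(input, dest, stride, bd);    /* DC only */
    else if (eob <= 10)
      vpx_highbd_idct16x16_10_add(input, dest, stride, bd);   /* top-left 4x4 */
    else if (eob <= 38)
      vpx_highbd_idct16x16_38_add(input, dest, stride, bd);   /* top-left 8x8 */
    else
      vpx_highbd_idct16x16_256_add(input, dest, stride, bd);  /* full block */
  }
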
diff --git a/vpx_dsp/arm/idct16x16_add_neon.c b/vpx_dsp/arm/idct16x16_add_neon.c
index 30bdb48..4259cd8 100644
--- a/vpx_dsp/arm/idct16x16_add_neon.c
+++ b/vpx_dsp/arm/idct16x16_add_neon.c
@@ -484,8 +484,7 @@
   }
 }
 
-static void idct16x16_10_add_half1d_pass1(const tran_low_t *input,
-                                          int16_t *output) {
+void idct16x16_10_add_half1d_pass1(const tran_low_t *input, int16_t *output) {
   const int16x8_t cospis0 = vld1q_s16(kCospi);
   const int16x8_t cospis1 = vld1q_s16(kCospi + 8);
   const int16x8_t cospisd0 = vaddq_s16(cospis0, cospis0);
@@ -638,8 +637,9 @@
   vst1_s16(output, out[15]);
 }
 
-static void idct16x16_10_add_half1d_pass2(const int16_t *input, int16_t *output,
-                                          uint8_t *dest, const int stride) {
+void idct16x16_10_add_half1d_pass2(const int16_t *input, int16_t *const output,
+                                   void *const dest, const int stride,
+                                   const int highbd_flag) {
   const int16x8_t cospis0 = vld1q_s16(kCospi);
   const int16x8_t cospis1 = vld1q_s16(kCospi + 8);
   const int16x8_t cospisd0 = vaddq_s16(cospis0, cospis0);
@@ -751,27 +751,16 @@
   step2[15] = step1[15];
 
   // stage 7
-  out[0] = vaddq_s16(step2[0], step2[15]);
-  out[1] = vaddq_s16(step2[1], step2[14]);
-  out[2] = vaddq_s16(step2[2], step2[13]);
-  out[3] = vaddq_s16(step2[3], step2[12]);
-  out[4] = vaddq_s16(step2[4], step2[11]);
-  out[5] = vaddq_s16(step2[5], step2[10]);
-  out[6] = vaddq_s16(step2[6], step2[9]);
-  out[7] = vaddq_s16(step2[7], step2[8]);
-  out[8] = vsubq_s16(step2[7], step2[8]);
-  out[9] = vsubq_s16(step2[6], step2[9]);
-  out[10] = vsubq_s16(step2[5], step2[10]);
-  out[11] = vsubq_s16(step2[4], step2[11]);
-  out[12] = vsubq_s16(step2[3], step2[12]);
-  out[13] = vsubq_s16(step2[2], step2[13]);
-  out[14] = vsubq_s16(step2[1], step2[14]);
-  out[15] = vsubq_s16(step2[0], step2[15]);
+  idct16x16_add_stage7(step2, out);
 
   if (output) {
     idct16x16_store_pass1(out, output);
   } else {
-    idct16x16_add_store(out, dest, stride);
+    if (highbd_flag) {
+      idct16x16_add_store_bd8(out, dest, stride);
+    } else {
+      idct16x16_add_store(out, dest, stride);
+    }
   }
 }
 
@@ -821,9 +810,9 @@
 
   // pass 2
   // Parallel idct to get the left 8 columns
-  idct16x16_10_add_half1d_pass2(row_idct_output, NULL, dest, stride);
+  idct16x16_10_add_half1d_pass2(row_idct_output, NULL, dest, stride, 0);
 
   // Parallel idct to get the right 8 columns
-  idct16x16_10_add_half1d_pass2(row_idct_output + 4 * 8, NULL, dest + 8,
-                                stride);
+  idct16x16_10_add_half1d_pass2(row_idct_output + 4 * 8, NULL, dest + 8, stride,
+                                0);
 }
diff --git a/vpx_dsp/arm/idct_neon.h b/vpx_dsp/arm/idct_neon.h
index 0a2052c..fe5b603 100644
--- a/vpx_dsp/arm/idct_neon.h
+++ b/vpx_dsp/arm/idct_neon.h
@@ -752,4 +752,10 @@
                              void *const dest, const int stride,
                              const int highbd_flag);
 
+void idct16x16_10_add_half1d_pass1(const tran_low_t *input, int16_t *output);
+
+void idct16x16_10_add_half1d_pass2(const int16_t *input, int16_t *const output,
+                                   void *const dest, const int stride,
+                                   const int highbd_flag);
+
 #endif  // VPX_DSP_ARM_IDCT_NEON_H_
diff --git a/vpx_dsp/vpx_dsp_rtcd_defs.pl b/vpx_dsp/vpx_dsp_rtcd_defs.pl
index 165def0..e9e7c6b 100644
--- a/vpx_dsp/vpx_dsp_rtcd_defs.pl
+++ b/vpx_dsp/vpx_dsp_rtcd_defs.pl
@@ -737,7 +737,7 @@
     $vpx_highbd_idct16x16_38_add_sse2=vpx_highbd_idct16x16_256_add_sse2;
 
     add_proto qw/void vpx_highbd_idct16x16_10_add/, "const tran_low_t *input, uint8_t *dest, int stride, int bd";
-    specialize qw/vpx_highbd_idct16x16_10_add sse2/;
+    specialize qw/vpx_highbd_idct16x16_10_add neon sse2/;
   }  # CONFIG_EMULATE_HARDWARE
 } else {
   # Force C versions if CONFIG_EMULATE_HARDWARE is 1
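
For context (not part of the patch): adding neon to the specialize line is what
publishes the new function through the runtime-dispatch header. Roughly, rtcd.pl turns
the proto into something like the sketch below on an ARM build with runtime CPU
detection (illustrative only, not the literal generated vpx_dsp_rtcd.h):

  void vpx_highbd_idct16x16_10_add_c(const tran_low_t *input, uint8_t *dest,
                                     int stride, int bd);
  void vpx_highbd_idct16x16_10_add_neon(const tran_low_t *input, uint8_t *dest,
                                        int stride, int bd);
  RTCD_EXTERN void (*vpx_highbd_idct16x16_10_add)(const tran_low_t *input,
                                                  uint8_t *dest, int stride,
                                                  int bd);

  /* in setup_rtcd_internal(): */
  vpx_highbd_idct16x16_10_add = vpx_highbd_idct16x16_10_add_c;
  if (flags & HAS_NEON)
    vpx_highbd_idct16x16_10_add = vpx_highbd_idct16x16_10_add_neon;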