Add vp9_highbd_iht4x4_16_add_sse4_1()

BUG=webm:1413

Change-Id: I14930d0af24370a44ab359de5bba5512eef4e29f
diff --git a/test/dct_test.cc b/test/dct_test.cc
index 5b228ff..49b84f1 100644
--- a/test/dct_test.cc
+++ b/test/dct_test.cc
@@ -746,7 +746,9 @@
 
 INSTANTIATE_TEST_CASE_P(C, TransHT, ::testing::ValuesIn(c_ht_tests));
 
-#if HAVE_SSE2 && !CONFIG_EMULATE_HARDWARE
+#if !CONFIG_EMULATE_HARDWARE
+
+#if HAVE_SSE2
 INSTANTIATE_TEST_CASE_P(
     SSE2, TransHT,
     ::testing::Values(
@@ -776,7 +778,51 @@
                    VPX_BITS_8, 1),
         make_tuple(&vp9_fht4x4_sse2, &iht_wrapper<vp9_iht4x4_16_add_sse2>, 4, 3,
                    VPX_BITS_8, 1)));
-#endif  // HAVE_SSE2 && !CONFIG_EMULATE_HARDWARE
+#endif  // HAVE_SSE2
+
+#if HAVE_SSE4_1 && CONFIG_VP9_HIGHBITDEPTH
+INSTANTIATE_TEST_CASE_P(
+    SSE4_1, TransHT,
+    ::testing::Values(
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 0,
+                   VPX_BITS_8, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 1,
+                   VPX_BITS_8, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 2,
+                   VPX_BITS_8, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 3,
+                   VPX_BITS_8, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 0,
+                   VPX_BITS_10, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 1,
+                   VPX_BITS_10, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 2,
+                   VPX_BITS_10, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 3,
+                   VPX_BITS_10, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 0,
+                   VPX_BITS_12, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 1,
+                   VPX_BITS_12, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 2,
+                   VPX_BITS_12, 2),
+        make_tuple(&vp9_highbd_fht4x4_c,
+                   &highbd_iht_wrapper<vp9_highbd_iht4x4_16_add_sse4_1>, 4, 3,
+                   VPX_BITS_12, 2)));
+#endif  // HAVE_SSE4_1 && CONFIG_VP9_HIGHBITDEPTH
+
+#endif  // !CONFIG_EMULATE_HARDWARE
 
 /* -------------------------------------------------------------------------- */
 
diff --git a/vp9/common/vp9_rtcd_defs.pl b/vp9/common/vp9_rtcd_defs.pl
index 22b67ec..dd61202 100644
--- a/vp9/common/vp9_rtcd_defs.pl
+++ b/vp9/common/vp9_rtcd_defs.pl
@@ -97,6 +97,9 @@
   # Note as optimized versions of these functions are added we need to add a check to ensure
   # that when CONFIG_EMULATE_HARDWARE is on, it defaults to the C versions only.
   add_proto qw/void vp9_highbd_iht4x4_16_add/, "const tran_low_t *input, uint16_t *dest, int stride, int tx_type, int bd";
+  if (vpx_config("CONFIG_EMULATE_HARDWARE") ne "yes") {
+    specialize qw/vp9_highbd_iht4x4_16_add sse4_1/;
+  }
 
   add_proto qw/void vp9_highbd_iht8x8_64_add/, "const tran_low_t *input, uint16_t *dest, int stride, int tx_type, int bd";
 
diff --git a/vp9/common/x86/vp9_highbd_iht4x4_add_sse4.c b/vp9/common/x86/vp9_highbd_iht4x4_add_sse4.c
new file mode 100644
index 0000000..af15853
--- /dev/null
+++ b/vp9/common/x86/vp9_highbd_iht4x4_add_sse4.c
@@ -0,0 +1,131 @@
+/*
+ *  Copyright (c) 2018 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "./vp9_rtcd.h"
+#include "vp9/common/vp9_idct.h"
+#include "vpx_dsp/x86/highbd_inv_txfm_sse4.h"
+#include "vpx_dsp/x86/inv_txfm_sse2.h"
+#include "vpx_dsp/x86/transpose_sse2.h"
+#include "vpx_dsp/x86/txfm_common_sse2.h"
+
+static INLINE void highbd_iadst4_sse4_1(__m128i *const io) {
+  const __m128i pair_c1 = pair_set_epi32(4 * sinpi_1_9, 0);
+  const __m128i pair_c2 = pair_set_epi32(4 * sinpi_2_9, 0);
+  const __m128i pair_c3 = pair_set_epi32(4 * sinpi_3_9, 0);
+  const __m128i pair_c4 = pair_set_epi32(4 * sinpi_4_9, 0);
+  __m128i s0[2], s1[2], s2[2], s3[2], s4[2], s5[2], s6[2], t0[2], t1[2], t2[2];
+  __m128i temp[2];
+
+  transpose_32bit_4x4(io, io);
+
+  extend_64bit(io[0], temp);
+  s0[0] = _mm_mul_epi32(pair_c1, temp[0]);
+  s0[1] = _mm_mul_epi32(pair_c1, temp[1]);
+  s1[0] = _mm_mul_epi32(pair_c2, temp[0]);
+  s1[1] = _mm_mul_epi32(pair_c2, temp[1]);
+
+  extend_64bit(io[1], temp);
+  s2[0] = _mm_mul_epi32(pair_c3, temp[0]);
+  s2[1] = _mm_mul_epi32(pair_c3, temp[1]);
+
+  extend_64bit(io[2], temp);
+  s3[0] = _mm_mul_epi32(pair_c4, temp[0]);
+  s3[1] = _mm_mul_epi32(pair_c4, temp[1]);
+  s4[0] = _mm_mul_epi32(pair_c1, temp[0]);
+  s4[1] = _mm_mul_epi32(pair_c1, temp[1]);
+
+  extend_64bit(io[3], temp);
+  s5[0] = _mm_mul_epi32(pair_c2, temp[0]);
+  s5[1] = _mm_mul_epi32(pair_c2, temp[1]);
+  s6[0] = _mm_mul_epi32(pair_c4, temp[0]);
+  s6[1] = _mm_mul_epi32(pair_c4, temp[1]);
+
+  t0[0] = _mm_add_epi64(s0[0], s3[0]);
+  t0[1] = _mm_add_epi64(s0[1], s3[1]);
+  t0[0] = _mm_add_epi64(t0[0], s5[0]);
+  t0[1] = _mm_add_epi64(t0[1], s5[1]);
+  t1[0] = _mm_sub_epi64(s1[0], s4[0]);
+  t1[1] = _mm_sub_epi64(s1[1], s4[1]);
+  t1[0] = _mm_sub_epi64(t1[0], s6[0]);
+  t1[1] = _mm_sub_epi64(t1[1], s6[1]);
+  temp[0] = _mm_sub_epi32(io[0], io[2]);
+  temp[0] = _mm_add_epi32(temp[0], io[3]);
+  extend_64bit(temp[0], temp);
+  t2[0] = _mm_mul_epi32(pair_c3, temp[0]);
+  t2[1] = _mm_mul_epi32(pair_c3, temp[1]);
+
+  s0[0] = _mm_add_epi64(t0[0], s2[0]);
+  s0[1] = _mm_add_epi64(t0[1], s2[1]);
+  s1[0] = _mm_add_epi64(t1[0], s2[0]);
+  s1[1] = _mm_add_epi64(t1[1], s2[1]);
+  s3[0] = _mm_add_epi64(t0[0], t1[0]);
+  s3[1] = _mm_add_epi64(t0[1], t1[1]);
+  s3[0] = _mm_sub_epi64(s3[0], s2[0]);
+  s3[1] = _mm_sub_epi64(s3[1], s2[1]);
+
+  s0[0] = dct_const_round_shift_64bit(s0[0]);
+  s0[1] = dct_const_round_shift_64bit(s0[1]);
+  s1[0] = dct_const_round_shift_64bit(s1[0]);
+  s1[1] = dct_const_round_shift_64bit(s1[1]);
+  s2[0] = dct_const_round_shift_64bit(t2[0]);
+  s2[1] = dct_const_round_shift_64bit(t2[1]);
+  s3[0] = dct_const_round_shift_64bit(s3[0]);
+  s3[1] = dct_const_round_shift_64bit(s3[1]);
+  io[0] = pack_4(s0[0], s0[1]);
+  io[1] = pack_4(s1[0], s1[1]);
+  io[2] = pack_4(s2[0], s2[1]);
+  io[3] = pack_4(s3[0], s3[1]);
+}
+
+void vp9_highbd_iht4x4_16_add_sse4_1(const tran_low_t *input, uint16_t *dest,
+                                     int stride, int tx_type, int bd) {
+  __m128i io[4];
+
+  io[0] = _mm_load_si128((const __m128i *)(input + 0));
+  io[1] = _mm_load_si128((const __m128i *)(input + 4));
+  io[2] = _mm_load_si128((const __m128i *)(input + 8));
+  io[3] = _mm_load_si128((const __m128i *)(input + 12));
+
+  if (bd == 8) {
+    __m128i io_short[2];
+
+    io_short[0] = _mm_packs_epi32(io[0], io[1]);
+    io_short[1] = _mm_packs_epi32(io[2], io[3]);
+    if (tx_type == DCT_DCT || tx_type == ADST_DCT) {
+      idct4_sse2(io_short);
+    } else {
+      iadst4_sse2(io_short);
+    }
+    if (tx_type == DCT_DCT || tx_type == DCT_ADST) {
+      idct4_sse2(io_short);
+    } else {
+      iadst4_sse2(io_short);
+    }
+    io_short[0] = _mm_add_epi16(io_short[0], _mm_set1_epi16(8));
+    io_short[1] = _mm_add_epi16(io_short[1], _mm_set1_epi16(8));
+    io[0] = _mm_srai_epi16(io_short[0], 4);
+    io[1] = _mm_srai_epi16(io_short[1], 4);
+  } else {
+    if (tx_type == DCT_DCT || tx_type == ADST_DCT) {
+      highbd_idct4_sse4_1(io);
+    } else {
+      highbd_iadst4_sse4_1(io);
+    }
+    if (tx_type == DCT_DCT || tx_type == DCT_ADST) {
+      highbd_idct4_sse4_1(io);
+    } else {
+      highbd_iadst4_sse4_1(io);
+    }
+    io[0] = wraplow_16bit_shift4(io[0], io[1], _mm_set1_epi32(8));
+    io[1] = wraplow_16bit_shift4(io[2], io[3], _mm_set1_epi32(8));
+  }
+
+  recon_and_store_4x4(io, dest, stride, bd);
+}
diff --git a/vp9/vp9_common.mk b/vp9/vp9_common.mk
index 2fb9a55..9819fb6 100644
--- a/vp9/vp9_common.mk
+++ b/vp9/vp9_common.mk
@@ -80,6 +80,8 @@
 VP9_COMMON_SRCS-$(HAVE_DSPR2) += common/mips/dspr2/vp9_itrans16_dspr2.c
 VP9_COMMON_SRCS-$(HAVE_NEON)  += common/arm/neon/vp9_iht4x4_add_neon.c
 VP9_COMMON_SRCS-$(HAVE_NEON)  += common/arm/neon/vp9_iht8x8_add_neon.c
+else
+VP9_COMMON_SRCS-$(HAVE_SSE4_1) += common/x86/vp9_highbd_iht4x4_add_sse4.c
 endif
 
 $(eval $(call rtcd_h_template,vp9_rtcd,vp9/common/vp9_rtcd_defs.pl))
diff --git a/vpx_dsp/x86/highbd_idct4x4_add_sse4.c b/vpx_dsp/x86/highbd_idct4x4_add_sse4.c
index 38e64f3..fe74d27 100644
--- a/vpx_dsp/x86/highbd_idct4x4_add_sse4.c
+++ b/vpx_dsp/x86/highbd_idct4x4_add_sse4.c
@@ -16,28 +16,6 @@
 #include "vpx_dsp/x86/inv_txfm_sse2.h"
 #include "vpx_dsp/x86/transpose_sse2.h"
 
-static INLINE void highbd_idct4(__m128i *const io) {
-  __m128i temp[2], step[4];
-
-  transpose_32bit_4x4(io, io);
-
-  // stage 1
-  temp[0] = _mm_add_epi32(io[0], io[2]);  // input[0] + input[2]
-  extend_64bit(temp[0], temp);
-  step[0] = multiplication_round_shift_sse4_1(temp, cospi_16_64);
-  temp[0] = _mm_sub_epi32(io[0], io[2]);  // input[0] - input[2]
-  extend_64bit(temp[0], temp);
-  step[1] = multiplication_round_shift_sse4_1(temp, cospi_16_64);
-  highbd_butterfly_sse4_1(io[1], io[3], cospi_24_64, cospi_8_64, &step[2],
-                          &step[3]);
-
-  // stage 2
-  io[0] = _mm_add_epi32(step[0], step[3]);  // step[0] + step[3]
-  io[1] = _mm_add_epi32(step[1], step[2]);  // step[1] + step[2]
-  io[2] = _mm_sub_epi32(step[1], step[2]);  // step[1] - step[2]
-  io[3] = _mm_sub_epi32(step[0], step[3]);  // step[0] - step[3]
-}
-
 void vpx_highbd_idct4x4_16_add_sse4_1(const tran_low_t *input, uint16_t *dest,
                                       int stride, int bd) {
   __m128i io[4];
@@ -59,8 +37,8 @@
     io[0] = _mm_srai_epi16(io_short[0], 4);
     io[1] = _mm_srai_epi16(io_short[1], 4);
   } else {
-    highbd_idct4(io);
-    highbd_idct4(io);
+    highbd_idct4_sse4_1(io);
+    highbd_idct4_sse4_1(io);
     io[0] = wraplow_16bit_shift4(io[0], io[1], _mm_set1_epi32(8));
     io[1] = wraplow_16bit_shift4(io[2], io[3], _mm_set1_epi32(8));
   }
diff --git a/vpx_dsp/x86/highbd_inv_txfm_sse2.h b/vpx_dsp/x86/highbd_inv_txfm_sse2.h
index e0f7495..c89666b 100644
--- a/vpx_dsp/x86/highbd_inv_txfm_sse2.h
+++ b/vpx_dsp/x86/highbd_inv_txfm_sse2.h
@@ -19,6 +19,10 @@
 #include "vpx_dsp/x86/transpose_sse2.h"
 #include "vpx_dsp/x86/txfm_common_sse2.h"
 
+// Note: SSE2 has no 64-bit arithmetic right shift SIMD instruction. All
+// coefficients are left shifted by 2, so that dct_const_round_shift() can be
+// done by right shifting 2 bytes (16 bits) instead.
+
 static INLINE void extend_64bit(const __m128i in,
                                 __m128i *const out /*out[2]*/) {
   out[0] = _mm_unpacklo_epi32(in, in);  // 0, 0, 1, 1
diff --git a/vpx_dsp/x86/highbd_inv_txfm_sse4.h b/vpx_dsp/x86/highbd_inv_txfm_sse4.h
index 9c8eef4..435934f 100644
--- a/vpx_dsp/x86/highbd_inv_txfm_sse4.h
+++ b/vpx_dsp/x86/highbd_inv_txfm_sse4.h
@@ -84,4 +84,26 @@
   *out1 = multiplication_round_shift_sse4_1(temp, c1);
 }
 
+static INLINE void highbd_idct4_sse4_1(__m128i *const io) {
+  __m128i temp[2], step[4];
+
+  transpose_32bit_4x4(io, io);
+
+  // stage 1
+  temp[0] = _mm_add_epi32(io[0], io[2]);  // input[0] + input[2]
+  extend_64bit(temp[0], temp);
+  step[0] = multiplication_round_shift_sse4_1(temp, cospi_16_64);
+  temp[0] = _mm_sub_epi32(io[0], io[2]);  // input[0] - input[2]
+  extend_64bit(temp[0], temp);
+  step[1] = multiplication_round_shift_sse4_1(temp, cospi_16_64);
+  highbd_butterfly_sse4_1(io[1], io[3], cospi_24_64, cospi_8_64, &step[2],
+                          &step[3]);
+
+  // stage 2
+  io[0] = _mm_add_epi32(step[0], step[3]);  // step[0] + step[3]
+  io[1] = _mm_add_epi32(step[1], step[2]);  // step[1] + step[2]
+  io[2] = _mm_sub_epi32(step[1], step[2]);  // step[1] - step[2]
+  io[3] = _mm_sub_epi32(step[0], step[3]);  // step[0] - step[3]
+}
+
 #endif  // VPX_DSP_X86_HIGHBD_INV_TXFM_SSE4_H_