Fix the use of HBD buffers for 8bit input

BUG=aomedia:1563

Change-Id: Ie9becbf4b6d4230c784a1a7c86f39e66bf8307a1
diff --git a/aom_dsp/fastssim.c b/aom_dsp/fastssim.c
index 1d681fc..87060aa 100644
--- a/aom_dsp/fastssim.c
+++ b/aom_dsp/fastssim.c
@@ -138,8 +138,8 @@
 
 static void fs_downsample_level0(fs_ctx *_ctx, const uint8_t *_src1,
                                  int _s1ystride, const uint8_t *_src2,
-                                 int _s2ystride, int _w, int _h, uint32_t bd,
-                                 uint32_t shift) {
+                                 int _s2ystride, int _w, int _h, uint32_t shift,
+                                 int buf_is_hbd) {
   uint32_t *dst1;
   uint32_t *dst2;
   int w;
@@ -160,7 +160,7 @@
       int i1;
       i0 = 2 * i;
       i1 = FS_MINI(i0 + 1, _w);
-      if (bd == 8 && shift == 0) {
+      if (!buf_is_hbd) {
         dst1[j * w + i] =
             _src1[j0 * _s1ystride + i0] + _src1[j0 * _s1ystride + i1] +
             _src1[j1 * _s1ystride + i0] + _src1[j1 * _s1ystride + i1];
@@ -439,14 +439,14 @@
 
 static double calc_ssim(const uint8_t *_src, int _systride, const uint8_t *_dst,
                         int _dystride, int _w, int _h, uint32_t _bd,
-                        uint32_t _shift) {
+                        uint32_t _shift, int buf_is_hbd) {
   fs_ctx ctx;
   double ret;
   int l;
   ret = 1;
   fs_ctx_init(&ctx, _w, _h, FS_NLEVELS);
-  fs_downsample_level0(&ctx, _src, _systride, _dst, _dystride, _w, _h, _bd,
-                       _shift);
+  fs_downsample_level0(&ctx, _src, _systride, _dst, _dystride, _w, _h, _shift,
+                       buf_is_hbd);
   for (l = 0; l < FS_NLEVELS - 1; l++) {
     fs_calc_structure(&ctx, l, _bd);
     ret *= fs_average(&ctx, l);
@@ -467,18 +467,19 @@
   uint32_t bd_shift = 0;
   aom_clear_system_state();
   assert(bd >= in_bd);
-
+  assert(source->flags == dest->flags);
+  int buf_is_hbd = source->flags & YV12_FLAG_HIGHBITDEPTH;
   bd_shift = bd - in_bd;
 
   *ssim_y = calc_ssim(source->y_buffer, source->y_stride, dest->y_buffer,
                       dest->y_stride, source->y_crop_width,
-                      source->y_crop_height, in_bd, bd_shift);
+                      source->y_crop_height, in_bd, bd_shift, buf_is_hbd);
   *ssim_u = calc_ssim(source->u_buffer, source->uv_stride, dest->u_buffer,
                       dest->uv_stride, source->uv_crop_width,
-                      source->uv_crop_height, in_bd, bd_shift);
+                      source->uv_crop_height, in_bd, bd_shift, buf_is_hbd);
   *ssim_v = calc_ssim(source->v_buffer, source->uv_stride, dest->v_buffer,
                       dest->uv_stride, source->uv_crop_width,
-                      source->uv_crop_height, in_bd, bd_shift);
+                      source->uv_crop_height, in_bd, bd_shift, buf_is_hbd);
   ssimv = (*ssim_y) * .8 + .1 * ((*ssim_u) + (*ssim_v));
   return convert_ssim_db(ssimv, 1.0);
 }
diff --git a/aom_dsp/psnrhvs.c b/aom_dsp/psnrhvs.c
index 324b387..90743a4 100644
--- a/aom_dsp/psnrhvs.c
+++ b/aom_dsp/psnrhvs.c
@@ -121,7 +121,8 @@
 static double calc_psnrhvs(const unsigned char *src, int _systride,
                            const unsigned char *dst, int _dystride, double _par,
                            int _w, int _h, int _step, const double _csf[8][8],
-                           uint32_t bit_depth, uint32_t _shift) {
+                           uint32_t bit_depth, uint32_t _shift,
+                           int buf_is_hbd) {
   double ret;
   const uint8_t *_src8 = src;
   const uint8_t *_dst8 = dst;
@@ -176,7 +177,7 @@
       for (i = 0; i < 8; i++) {
         for (j = 0; j < 8; j++) {
           int sub = ((i & 12) >> 2) + ((j & 12) >> 1);
-          if (bit_depth == 8 && _shift == 0) {
+          if (!buf_is_hbd) {
             dct_s[i * 8 + j] = _src8[(y + i) * _systride + (j + x)];
             dct_d[i * 8 + j] = _dst8[(y + i) * _dystride + (j + x)];
           } else if (bit_depth == 10 || bit_depth == 12) {
@@ -254,21 +255,25 @@
   const int step = 7;
   uint32_t bd_shift = 0;
   aom_clear_system_state();
-
   assert(bd == 8 || bd == 10 || bd == 12);
   assert(bd >= in_bd);
+  assert(src->flags == dst->flags);
+  int buf_is_hbd = src->flags & YV12_FLAG_HIGHBITDEPTH;
 
   bd_shift = bd - in_bd;
 
-  *y_psnrhvs = calc_psnrhvs(src->y_buffer, src->y_stride, dst->y_buffer,
-                            dst->y_stride, par, src->y_crop_width,
-                            src->y_crop_height, step, csf_y, bd, bd_shift);
-  *u_psnrhvs = calc_psnrhvs(src->u_buffer, src->uv_stride, dst->u_buffer,
-                            dst->uv_stride, par, src->uv_crop_width,
-                            src->uv_crop_height, step, csf_cb420, bd, bd_shift);
-  *v_psnrhvs = calc_psnrhvs(src->v_buffer, src->uv_stride, dst->v_buffer,
-                            dst->uv_stride, par, src->uv_crop_width,
-                            src->uv_crop_height, step, csf_cr420, bd, bd_shift);
+  *y_psnrhvs =
+      calc_psnrhvs(src->y_buffer, src->y_stride, dst->y_buffer, dst->y_stride,
+                   par, src->y_crop_width, src->y_crop_height, step, csf_y, bd,
+                   bd_shift, buf_is_hbd);
+  *u_psnrhvs =
+      calc_psnrhvs(src->u_buffer, src->uv_stride, dst->u_buffer, dst->uv_stride,
+                   par, src->uv_crop_width, src->uv_crop_height, step,
+                   csf_cb420, bd, bd_shift, buf_is_hbd);
+  *v_psnrhvs =
+      calc_psnrhvs(src->v_buffer, src->uv_stride, dst->v_buffer, dst->uv_stride,
+                   par, src->uv_crop_width, src->uv_crop_height, step,
+                   csf_cr420, bd, bd_shift, buf_is_hbd);
   psnrhvs = (*y_psnrhvs) * .8 + .1 * ((*u_psnrhvs) + (*v_psnrhvs));
   return convert_score_db(psnrhvs, 1.0, in_bd);
 }