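daala-dist: high bit depth support

Pass MACROBLOCKD to av1_daala_dist() and av1_daala_dist_diff() so they
can test YV12_FLAG_HIGHBITDEPTH and read high-bitdepth source and
reconstruction samples through CONVERT_TO_SHORTPTR. Apply the same
8-bit/high-bitdepth split where reconstructions are copied into pd->pred
and where temporary prediction buffers are built for av1_daala_dist().
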
Change-Id: Idafef140d3425a9a9f66cb8864a804c4d2a89a70
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index a3c0844..be12985 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -728,58 +728,105 @@
   return sum;
 }
 
-int64_t av1_daala_dist(const uint8_t *src, int src_stride, const uint8_t *dst,
-                       int dst_stride, int bsw, int bsh, int visible_w,
-                       int visible_h, int qm, int use_activity_masking,
-                       int qindex) {
+int64_t av1_daala_dist(const MACROBLOCKD *xd, const uint8_t *src,
+                       int src_stride, const uint8_t *dst, int dst_stride,
+                       int bsw, int bsh, int visible_w, int visible_h, int qm,
+                       int use_activity_masking, int qindex) {
   int i, j;
   int64_t d;
   DECLARE_ALIGNED(16, od_coeff, orig[MAX_TX_SQUARE]);
   DECLARE_ALIGNED(16, od_coeff, rec[MAX_TX_SQUARE]);
-
+#if !CONFIG_HIGHBITDEPTH
+  (void)xd;  // unused unless CONFIG_HIGHBITDEPTH
+#endif  // !CONFIG_HIGHBITDEPTH
   assert(qm == OD_HVS_QM);
 
-  for (j = 0; j < bsh; j++)
-    for (i = 0; i < bsw; i++) orig[j * bsw + i] = src[j * src_stride + i];
-
-  if ((bsw == visible_w) && (bsh == visible_h)) {
+#if CONFIG_HIGHBITDEPTH
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {  // 16-bit src/dst
     for (j = 0; j < bsh; j++)
-      for (i = 0; i < bsw; i++) rec[j * bsw + i] = dst[j * dst_stride + i];
-  } else {
-    for (j = 0; j < visible_h; j++)
-      for (i = 0; i < visible_w; i++)
-        rec[j * bsw + i] = dst[j * dst_stride + i];
+      for (i = 0; i < bsw; i++)
+        orig[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
 
-    if (visible_w < bsw) {
+    if ((bsw == visible_w) && (bsh == visible_h)) {
       for (j = 0; j < bsh; j++)
-        for (i = visible_w; i < bsw; i++)
-          rec[j * bsw + i] = src[j * src_stride + i];
-    }
+        for (i = 0; i < bsw; i++)
+          rec[j * bsw + i] = CONVERT_TO_SHORTPTR(dst)[j * dst_stride + i];
+    } else {
+      for (j = 0; j < visible_h; j++)
+        for (i = 0; i < visible_w; i++)
+          rec[j * bsw + i] = CONVERT_TO_SHORTPTR(dst)[j * dst_stride + i];
 
-    if (visible_h < bsh) {
-      for (j = visible_h; j < bsh; j++)
-        for (i = 0; i < bsw; i++) rec[j * bsw + i] = src[j * src_stride + i];
+      if (visible_w < bsw) {
+        for (j = 0; j < bsh; j++)
+          for (i = visible_w; i < bsw; i++)
+            rec[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
+      }
+
+      if (visible_h < bsh) {
+        for (j = visible_h; j < bsh; j++)
+          for (i = 0; i < bsw; i++)
+            rec[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
+      }
     }
+  } else {
+#endif  // CONFIG_HIGHBITDEPTH
+    for (j = 0; j < bsh; j++)
+      for (i = 0; i < bsw; i++) orig[j * bsw + i] = src[j * src_stride + i];
+    // rec pixels outside visible_w x visible_h are copied from src.
+    if ((bsw == visible_w) && (bsh == visible_h)) {
+      for (j = 0; j < bsh; j++)
+        for (i = 0; i < bsw; i++) rec[j * bsw + i] = dst[j * dst_stride + i];
+    } else {
+      for (j = 0; j < visible_h; j++)
+        for (i = 0; i < visible_w; i++)
+          rec[j * bsw + i] = dst[j * dst_stride + i];
+
+      if (visible_w < bsw) {
+        for (j = 0; j < bsh; j++)
+          for (i = visible_w; i < bsw; i++)
+            rec[j * bsw + i] = src[j * src_stride + i];
+      }
+
+      if (visible_h < bsh) {
+        for (j = visible_h; j < bsh; j++)
+          for (i = 0; i < bsw; i++) rec[j * bsw + i] = src[j * src_stride + i];
+      }
+    }
+#if CONFIG_HIGHBITDEPTH
   }
+#endif  // CONFIG_HIGHBITDEPTH
+  // NOTE(review): HBD samples are not rescaled to 8 bits; check scaling.
   d = (int64_t)od_compute_dist(qm, use_activity_masking, orig, rec, bsw, bsh,
                                qindex);
   return d;
 }
 
-static int64_t av1_daala_dist_diff(const uint8_t *src, int src_stride,
-                                   const int16_t *diff, int dst_stride, int bsw,
-                                   int bsh, int visible_w, int visible_h,
-                                   int qm, int use_activity_masking,
-                                   int qindex) {
+static int64_t av1_daala_dist_diff(const MACROBLOCKD *xd, const uint8_t *src,
+                                   int src_stride, const int16_t *diff,
+                                   int dst_stride, int bsw, int bsh,
+                                   int visible_w, int visible_h, int qm,
+                                   int use_activity_masking, int qindex) {
   int i, j;
   int64_t d;
   DECLARE_ALIGNED(16, od_coeff, orig[MAX_TX_SQUARE]);
   DECLARE_ALIGNED(16, od_coeff, diff32[MAX_TX_SQUARE]);
-
+#if !CONFIG_HIGHBITDEPTH
+  (void)xd;  // unused unless CONFIG_HIGHBITDEPTH
+#endif  // !CONFIG_HIGHBITDEPTH
   assert(qm == OD_HVS_QM);
 
-  for (j = 0; j < bsh; j++)
-    for (i = 0; i < bsw; i++) orig[j * bsw + i] = src[j * src_stride + i];
+#if CONFIG_HIGHBITDEPTH
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {  // 16-bit src
+    for (j = 0; j < bsh; j++)
+      for (i = 0; i < bsw; i++)
+        orig[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
+  } else {
+#endif  // CONFIG_HIGHBITDEPTH
+    for (j = 0; j < bsh; j++)
+      for (i = 0; i < bsw; i++) orig[j * bsw + i] = src[j * src_stride + i];
+#if CONFIG_HIGHBITDEPTH
+  }
+#endif  // CONFIG_HIGHBITDEPTH
 
   if ((bsw == visible_w) && (bsh == visible_h)) {
     for (j = 0; j < bsh; j++)
@@ -799,6 +846,7 @@
         for (i = 0; i < bsw; i++) diff32[j * bsw + i] = 0;
     }
   }
+  // NOTE(review): HBD orig is not rescaled to 8 bits; check scaling.
   d = (int64_t)od_compute_dist_diff(qm, use_activity_masking, orig, diff32, bsw,
                                     bsh, qindex);
 
@@ -1455,9 +1503,9 @@
 
 #if CONFIG_DAALA_DIST
   if (plane == 0 && txb_cols >= 8 && txb_rows >= 8)
-    return av1_daala_dist(src, src_stride, dst, dst_stride, txb_cols, txb_rows,
-                          visible_cols, visible_rows, qm, use_activity_masking,
-                          x->qindex);
+    return av1_daala_dist(xd, src, src_stride, dst, dst_stride, txb_cols,
+                          txb_rows, visible_cols, visible_rows, qm,
+                          use_activity_masking, x->qindex);  // xd: HBD check
 #endif  // CONFIG_DAALA_DIST
 
 #if CONFIG_EXT_TX && CONFIG_RECT_TX && CONFIG_RECT_TX_EXT
@@ -1509,9 +1557,9 @@
 
 #if CONFIG_DAALA_DIST
   if (plane == 0 && txb_width >= 8 && txb_height >= 8)
-    return av1_daala_dist_diff(src, src_stride, diff, diff_stride, txb_width,
-                               txb_height, visible_cols, visible_rows, qm,
-                               use_activity_masking, x->qindex);
+    return av1_daala_dist_diff(  // xd: HBD check
+        xd, src, src_stride, diff, diff_stride, txb_width, txb_height,
+        visible_cols, visible_rows, qm, use_activity_masking, x->qindex);
   else
 #endif
     return aom_sum_squares_2d_i16(diff, diff_stride, visible_cols,
@@ -1648,9 +1696,20 @@
           int16_t *pred = &pd->pred[pred_idx];
           int i, j;
 
-          for (j = 0; j < bsh; j++)
-            for (i = 0; i < bsw; i++)
-              pred[j * pred_stride + i] = recon[j * MAX_TX_SIZE + i];
+#if CONFIG_HIGHBITDEPTH
+          if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {  // 16-bit recon
+            for (j = 0; j < bsh; j++)
+              for (i = 0; i < bsw; i++)
+                pred[j * pred_stride + i] =
+                    CONVERT_TO_SHORTPTR(recon)[j * MAX_TX_SIZE + i];
+          } else {
+#endif  // CONFIG_HIGHBITDEPTH
+            for (j = 0; j < bsh; j++)
+              for (i = 0; i < bsw; i++)
+                pred[j * pred_stride + i] = recon[j * MAX_TX_SIZE + i];
+#if CONFIG_HIGHBITDEPTH
+          }
+#endif  // CONFIG_HIGHBITDEPTH
         }
 #endif  // CONFIG_DAALA_DIST
         *out_dist =
@@ -1842,15 +1901,35 @@
   assert((bw & 0x07) == 0);
   assert((bh & 0x07) == 0);
 
-  DECLARE_ALIGNED(16, uint8_t, pred8[MAX_SB_SQUARE]);
+#if CONFIG_HIGHBITDEPTH
+  uint8_t *pred8;  // refers to pred16; set below
+  DECLARE_ALIGNED(16, uint16_t, pred16[MAX_TX_SQUARE]);  // bw * bh used
 
-  for (j = 0; j < bh; j++)
-    for (i = 0; i < bw; i++) pred8[j * bw + i] = pred[j * bw + i];
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+    pred8 = CONVERT_TO_BYTEPTR(pred16);
+  else
+    pred8 = (uint8_t *)pred16;  // 8-bit: reuse storage as bytes
+#else
+  DECLARE_ALIGNED(16, uint8_t, pred8[MAX_TX_SQUARE]);  // bw * bh used
+#endif  // CONFIG_HIGHBITDEPTH
 
-  tmp1 = av1_daala_dist(src, src_stride, pred8, bw, bw, bh, bw, bh, qm,
+#if CONFIG_HIGHBITDEPTH
+  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
+    for (j = 0; j < bh; j++)
+      for (i = 0; i < bw; i++)
+        CONVERT_TO_SHORTPTR(pred8)[j * bw + i] = pred[j * bw + i];
+  } else {
+#endif  // CONFIG_HIGHBITDEPTH
+    for (j = 0; j < bh; j++)
+      for (i = 0; i < bw; i++) pred8[j * bw + i] = pred[j * bw + i];
+#if CONFIG_HIGHBITDEPTH
+  }
+#endif  // CONFIG_HIGHBITDEPTH
+  // tmp1: src vs. the copied prediction (pred8); tmp2: src vs. dst.
+  tmp1 = av1_daala_dist(xd, src, src_stride, pred8, bw, bw, bh, bw, bh, qm,
                         use_activity_masking, qindex);
-  tmp2 = av1_daala_dist(src, src_stride, dst, dst_stride, bw, bh, bw, bh, qm,
-                        use_activity_masking, qindex);
+  tmp2 = av1_daala_dist(xd, src, src_stride, dst, dst_stride, bw, bh, bw, bh,
+                        qm, use_activity_masking, qindex);
 
   if (!is_inter_block(mbmi)) {
     args->rd_stats.sse = (int64_t)tmp1 * 16;
@@ -3385,9 +3464,10 @@
     use_activity_masking = mb->daala_enc.use_activity_masking;
 #endif  // CONFIG_PVQ
     // Daala-defined distortion computed for the block of 8x8 pixels
-    total_distortion = av1_daala_dist(src, src_stride, dst, dst_stride, 8, 8, 8,
-                                      8, qm, use_activity_masking, mb->qindex)
-                       << 4;
+    total_distortion =
+        av1_daala_dist(xd, src, src_stride, dst, dst_stride, 8, 8, 8, 8, qm,
+                       use_activity_masking, mb->qindex)  // xd: HBD check
+        << 4;
   }
 #endif  // CONFIG_DAALA_DIST
   // Add in the cost of the transform type
@@ -4147,10 +4227,20 @@
       int16_t *decoded = &pd->pred[pred_idx];
       int i, j;
 
-      // TODO(yushin): HBD support
-      for (j = 0; j < bh; j++)
-        for (i = 0; i < bw; i++)
-          decoded[j * pred_stride + i] = rec_buffer[j * MAX_TX_SIZE + i];
+#if CONFIG_HIGHBITDEPTH
+      if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {  // 16-bit rec_buffer
+        for (j = 0; j < bh; j++)
+          for (i = 0; i < bw; i++)
+            decoded[j * pred_stride + i] =
+                CONVERT_TO_SHORTPTR(rec_buffer)[j * MAX_TX_SIZE + i];
+      } else {
+#endif  // CONFIG_HIGHBITDEPTH
+        for (j = 0; j < bh; j++)
+          for (i = 0; i < bw; i++)
+            decoded[j * pred_stride + i] = rec_buffer[j * MAX_TX_SIZE + i];
+#if CONFIG_HIGHBITDEPTH
+      }
+#endif  // CONFIG_HIGHBITDEPTH
     }
 #endif  // CONFIG_DAALA_DIST
     tmp = pixel_dist(cpi, x, plane, src, src_stride, rec_buffer, MAX_TX_SIZE,
@@ -4318,35 +4408,74 @@
       int use_activity_masking = 0;
       int row, col;
 
+#if CONFIG_HIGHBITDEPTH
+      uint8_t *pred8;  // refers to pred8_16; set below
+      DECLARE_ALIGNED(16, uint16_t, pred8_16[8 * 8]);
+#else
       DECLARE_ALIGNED(16, uint8_t, pred8[8 * 8]);
+#endif  // CONFIG_HIGHBITDEPTH
 
 #if CONFIG_PVQ
       use_activity_masking = x->daala_enc.use_activity_masking;
 #endif
-      daala_dist = av1_daala_dist(src, src_stride, dst, dst_stride, 8, 8, 8, 8,
-                                  qm, use_activity_masking, qindex) *
+      daala_dist = av1_daala_dist(xd, src, src_stride, dst, dst_stride, 8, 8, 8,
+                                  8, qm, use_activity_masking, qindex) *
                    16;
       sum_rd_stats.sse = daala_dist;
 
-      for (row = 0; row < 2; ++row) {
-        for (col = 0; col < 2; ++col) {
-          int idx = row * 2 + col;
-          int eob = sub8x8_eob[idx];
+#if CONFIG_HIGHBITDEPTH
+      if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+        pred8 = CONVERT_TO_BYTEPTR(pred8_16);
+      else
+        pred8 = (uint8_t *)pred8_16;  // 8-bit: reuse storage as bytes
+#endif  // CONFIG_HIGHBITDEPTH
 
-          if (eob > 0) {
-            for (j = 0; j < 4; j++)
-              for (i = 0; i < 4; i++)
-                pred8[(row * 4 + j) * 8 + 4 * col + i] =
-                    pred[(row * 4 + j) * pred_stride + 4 * col + i];
-          } else {
-            for (j = 0; j < 4; j++)
-              for (i = 0; i < 4; i++)
-                pred8[(row * 4 + j) * 8 + 4 * col + i] =
-                    dst[(row * 4 + j) * dst_stride + 4 * col + i];
+#if CONFIG_HIGHBITDEPTH
+      if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {  // 16-bit buffers
+        for (row = 0; row < 2; ++row) {
+          for (col = 0; col < 2; ++col) {
+            int idx = row * 2 + col;
+            int eob = sub8x8_eob[idx];
+            // eob > 0: copy this 4x4 from pred; otherwise copy it from dst.
+            if (eob > 0) {
+              for (j = 0; j < 4; j++)
+                for (i = 0; i < 4; i++)
+                  CONVERT_TO_SHORTPTR(pred8)
+                  [(row * 4 + j) * 8 + 4 * col + i] =
+                      pred[(row * 4 + j) * pred_stride + 4 * col + i];
+            } else {
+              for (j = 0; j < 4; j++)
+                for (i = 0; i < 4; i++)
+                  CONVERT_TO_SHORTPTR(pred8)
+                  [(row * 4 + j) * 8 + 4 * col + i] = CONVERT_TO_SHORTPTR(
+                      dst)[(row * 4 + j) * dst_stride + 4 * col + i];
+            }
           }
         }
+      } else {
+#endif  // CONFIG_HIGHBITDEPTH
+        for (row = 0; row < 2; ++row) {
+          for (col = 0; col < 2; ++col) {
+            int idx = row * 2 + col;
+            int eob = sub8x8_eob[idx];
+            // eob > 0: copy this 4x4 from pred; otherwise copy it from dst.
+            if (eob > 0) {
+              for (j = 0; j < 4; j++)
+                for (i = 0; i < 4; i++)
+                  pred8[(row * 4 + j) * 8 + 4 * col + i] =
+                      pred[(row * 4 + j) * pred_stride + 4 * col + i];
+            } else {
+              for (j = 0; j < 4; j++)
+                for (i = 0; i < 4; i++)
+                  pred8[(row * 4 + j) * 8 + 4 * col + i] =
+                      dst[(row * 4 + j) * dst_stride + 4 * col + i];
+            }
+          }
+        }
+#if CONFIG_HIGHBITDEPTH
       }
-      daala_dist = av1_daala_dist(src, src_stride, pred8, 8, 8, 8, 8, 8, qm,
+#endif  // CONFIG_HIGHBITDEPTH
+      daala_dist = av1_daala_dist(xd, src, src_stride, pred8, 8, 8, 8, 8, 8, qm,
                                   use_activity_masking, qindex) *
                    16;
       sum_rd_stats.dist = daala_dist;