palette-delta-encoding experiment

Transmit palette colors with delta encoding.
Coding gain on screen_content test set:
overall 0.67%  keyframe 1.37%

Change-Id: I72ce9061dfddf933e9f7530f069955afcb07edf8
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index ce518e6..df8f007 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -51,6 +51,9 @@
 #include "av1/encoder/cost.h"
 #include "av1/encoder/encodemv.h"
 #include "av1/encoder/mcomp.h"
+#if CONFIG_PALETTE && CONFIG_PALETTE_DELTA_ENCODING
+#include "av1/encoder/palette.h"
+#endif  // CONFIG_PALETTE && CONFIG_PALETTE_DELTA_ENCODING
 #include "av1/encoder/segmentation.h"
 #include "av1/encoder/subexp.h"
 #include "av1/encoder/tokenize.h"
@@ -1368,6 +1371,82 @@
 }
 
 #if CONFIG_PALETTE
+#if CONFIG_PALETTE_DELTA_ENCODING
+// Write luma palette color values with delta encoding. Write the first value as
+// a literal, then (delta - 1) for each later value, where delta is the gap to
+// the previous value. The luma palette is sorted, so each delta is at least 1.
+static void write_palette_colors_y(const PALETTE_MODE_INFO *const pmi,
+                                   int bit_depth, aom_writer *w) {
+  const int n = pmi->palette_size[0];
+  int min_bits, i;
+  int bits = av1_get_palette_delta_bits_y(pmi, bit_depth, &min_bits);
+  aom_write_literal(w, bits - min_bits, 2);
+  aom_write_literal(w, pmi->palette_colors[0], bit_depth);
+  for (i = 1; i < n; ++i) {
+    aom_write_literal(
+        w, pmi->palette_colors[i] - pmi->palette_colors[i - 1] - 1, bits);
+    bits =
+        AOMMIN(bits, av1_ceil_log2((1 << bit_depth) - pmi->palette_colors[i]));
+  }
+}
+
+// Write chroma palette color values. Use delta encoding for the U channel, as
+// its palette is sorted. For the V channel, either use delta encoding or
+// transmit raw values directly, whichever costs less.
+static void write_palette_colors_uv(const PALETTE_MODE_INFO *const pmi,
+                                    int bit_depth, aom_writer *w) {
+  int i;
+  const int n = pmi->palette_size[1];
+#if CONFIG_HIGHBITDEPTH
+  const uint16_t *colors_u = pmi->palette_colors + PALETTE_MAX_SIZE;
+  const uint16_t *colors_v = pmi->palette_colors + 2 * PALETTE_MAX_SIZE;
+#else
+  const uint8_t *colors_u = pmi->palette_colors + PALETTE_MAX_SIZE;
+  const uint8_t *colors_v = pmi->palette_colors + 2 * PALETTE_MAX_SIZE;
+#endif  // CONFIG_HIGHBITDEPTH
+  // U channel colors.
+  int min_bits_u = 0;
+  int bits_u = av1_get_palette_delta_bits_u(pmi, bit_depth, &min_bits_u);
+  aom_write_literal(w, bits_u - min_bits_u, 2);
+  aom_write_literal(w, colors_u[0], bit_depth);
+  for (i = 1; i < n; ++i) {
+    aom_write_literal(w, colors_u[i] - colors_u[i - 1], bits_u);
+    bits_u = AOMMIN(bits_u, av1_ceil_log2(1 + (1 << bit_depth) - colors_u[i]));
+  }
+  // V channel colors.
+  const int max_val = 1 << bit_depth;
+  int zero_count = 0, min_bits_v = 0;
+  int bits_v =
+      av1_get_palette_delta_bits_v(pmi, bit_depth, &zero_count, &min_bits_v);
+  const int rate_using_delta =
+      2 + bit_depth + (bits_v + 1) * (n - 1) - zero_count;
+  const int rate_using_raw = bit_depth * n;
+  if (rate_using_delta < rate_using_raw) {  // delta encoding
+    aom_write_bit(w, 1);
+    aom_write_literal(w, bits_v - min_bits_v, 2);
+    aom_write_literal(w, colors_v[0], bit_depth);
+    for (i = 1; i < n; ++i) {
+      if (colors_v[i] == colors_v[i - 1]) {  // No need to signal sign bit.
+        aom_write_literal(w, 0, bits_v);
+        continue;
+      }
+      const int delta = abs((int)colors_v[i] - colors_v[i - 1]);
+      const int sign_bit = colors_v[i] < colors_v[i - 1];
+      if (delta <= max_val - delta) {
+        aom_write_literal(w, delta, bits_v);
+        aom_write_bit(w, sign_bit);
+      } else {
+        aom_write_literal(w, max_val - delta, bits_v);
+        aom_write_bit(w, !sign_bit);
+      }
+    }
+  } else {  // Transmit raw values.
+    aom_write_bit(w, 0);
+    for (i = 0; i < n; ++i) aom_write_literal(w, colors_v[i], bit_depth);
+  }
+}
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+
 static void write_palette_mode_info(const AV1_COMMON *cm, const MACROBLOCKD *xd,
                                     const MODE_INFO *const mi, aom_writer *w) {
   const MB_MODE_INFO *const mbmi = &mi->mbmi;
@@ -1375,7 +1454,6 @@
   const MODE_INFO *const left_mi = xd->left_mi;
   const BLOCK_SIZE bsize = mbmi->sb_type;
   const PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
-  int i;
 
   if (mbmi->mode == DC_PRED) {
     const int n = pmi->palette_size[0];
@@ -1393,8 +1471,13 @@
       av1_write_token(w, av1_palette_size_tree,
                       av1_default_palette_y_size_prob[bsize - BLOCK_8X8],
                       &palette_size_encodings[n - PALETTE_MIN_SIZE]);
+#if CONFIG_PALETTE_DELTA_ENCODING
+      write_palette_colors_y(pmi, cm->bit_depth, w);
+#else
+      int i;
       for (i = 0; i < n; ++i)
         aom_write_literal(w, pmi->palette_colors[i], cm->bit_depth);
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
       write_uniform(w, n, pmi->palette_first_color_idx[0]);
     }
   }
@@ -1407,12 +1490,17 @@
       av1_write_token(w, av1_palette_size_tree,
                       av1_default_palette_uv_size_prob[bsize - BLOCK_8X8],
                       &palette_size_encodings[n - PALETTE_MIN_SIZE]);
+#if CONFIG_PALETTE_DELTA_ENCODING
+      write_palette_colors_uv(pmi, cm->bit_depth, w);
+#else
+      int i;
       for (i = 0; i < n; ++i) {
         aom_write_literal(w, pmi->palette_colors[PALETTE_MAX_SIZE + i],
                           cm->bit_depth);
         aom_write_literal(w, pmi->palette_colors[2 * PALETTE_MAX_SIZE + i],
                           cm->bit_depth);
       }
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
       write_uniform(w, n, pmi->palette_first_color_idx[1]);
     }
   }
diff --git a/av1/encoder/palette.c b/av1/encoder/palette.c
index 248681e..355141d 100644
--- a/av1/encoder/palette.c
+++ b/av1/encoder/palette.c
@@ -11,6 +11,8 @@
 
 #include <math.h>
 #include <stdlib.h>
+
+#include "av1/encoder/cost.h"
 #include "av1/encoder/palette.h"
 
 static float calc_dist(const float *p1, const float *p2, int dim) {
@@ -164,6 +166,89 @@
   return n;
 }
 
+#if CONFIG_PALETTE_DELTA_ENCODING
+int av1_get_palette_delta_bits_y(const PALETTE_MODE_INFO *const pmi,
+                                 int bit_depth, int *min_bits) {
+  const int n = pmi->palette_size[0];
+  int max_d = 0, i;
+  *min_bits = bit_depth - 3;
+  for (i = 1; i < n; ++i) {
+    const int delta = pmi->palette_colors[i] - pmi->palette_colors[i - 1];
+    assert(delta > 0);
+    if (delta > max_d) max_d = delta;
+  }
+  return AOMMAX(av1_ceil_log2(max_d), *min_bits);
+}
+
+int av1_get_palette_delta_bits_u(const PALETTE_MODE_INFO *const pmi,
+                                 int bit_depth, int *min_bits) {
+  const int n = pmi->palette_size[1];
+  int max_d = 0, i;
+  *min_bits = bit_depth - 3;
+  for (i = 1; i < n; ++i) {
+    const int delta = pmi->palette_colors[PALETTE_MAX_SIZE + i] -
+                      pmi->palette_colors[PALETTE_MAX_SIZE + i - 1];
+    assert(delta >= 0);
+    if (delta > max_d) max_d = delta;
+  }
+  return AOMMAX(av1_ceil_log2(max_d + 1), *min_bits);
+}
+
+int av1_get_palette_delta_bits_v(const PALETTE_MODE_INFO *const pmi,
+                                 int bit_depth, int *zero_count,
+                                 int *min_bits) {
+  const int n = pmi->palette_size[1];
+  const int max_val = 1 << bit_depth;
+  int max_d = 0, i;
+  *min_bits = bit_depth - 4;
+  *zero_count = 0;
+  for (i = 1; i < n; ++i) {
+    const int delta = pmi->palette_colors[2 * PALETTE_MAX_SIZE + i] -
+                      pmi->palette_colors[2 * PALETTE_MAX_SIZE + i - 1];
+    const int v = abs(delta);
+    const int d = AOMMIN(v, max_val - v);
+    if (d > max_d) max_d = d;
+    if (d == 0) ++(*zero_count);
+  }
+  return AOMMAX(av1_ceil_log2(max_d + 1), *min_bits);
+}
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+
+int av1_palette_color_cost_y(const PALETTE_MODE_INFO *const pmi,
+                             int bit_depth) {
+  const int n = pmi->palette_size[0];
+#if CONFIG_PALETTE_DELTA_ENCODING
+  int min_bits = 0;
+  const int bits = av1_get_palette_delta_bits_y(pmi, bit_depth, &min_bits);
+  return av1_cost_bit(128, 0) * (2 + bit_depth + bits * (n - 1));
+#else
+  return bit_depth * n * av1_cost_bit(128, 0);
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+}
+
+int av1_palette_color_cost_uv(const PALETTE_MODE_INFO *const pmi,
+                              int bit_depth) {
+  const int n = pmi->palette_size[1];
+#if CONFIG_PALETTE_DELTA_ENCODING
+  int cost = 0;
+  // U channel palette color cost.
+  int min_bits_u = 0;
+  const int bits_u = av1_get_palette_delta_bits_u(pmi, bit_depth, &min_bits_u);
+  cost += av1_cost_bit(128, 0) * (2 + bit_depth + bits_u * (n - 1));
+  // V channel palette color cost.
+  int zero_count = 0, min_bits_v = 0;
+  const int bits_v =
+      av1_get_palette_delta_bits_v(pmi, bit_depth, &zero_count, &min_bits_v);
+  const int bits_using_delta =
+      2 + bit_depth + (bits_v + 1) * (n - 1) - zero_count;
+  const int bits_using_raw = bit_depth * n;
+  cost += av1_cost_bit(128, 0) * (1 + AOMMIN(bits_using_delta, bits_using_raw));
+  return cost;
+#else
+  return 2 * bit_depth * n * av1_cost_bit(128, 0);
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+}
+
 #if CONFIG_HIGHBITDEPTH
 int av1_count_colors_highbd(const uint8_t *src8, int stride, int rows, int cols,
                             int bit_depth) {
diff --git a/av1/encoder/palette.h b/av1/encoder/palette.h
index 9898210..5403ac5 100644
--- a/av1/encoder/palette.h
+++ b/av1/encoder/palette.h
@@ -44,6 +44,28 @@
                             int bit_depth);
 #endif  // CONFIG_HIGHBITDEPTH
 
+#if CONFIG_PALETTE_DELTA_ENCODING
+// Return the number of bits used to transmit each luma palette color delta.
+int av1_get_palette_delta_bits_y(const PALETTE_MODE_INFO *const pmi,
+                                 int bit_depth, int *min_bits);
+
+// Return the number of bits used to transmit each U palette color delta.
+int av1_get_palette_delta_bits_u(const PALETTE_MODE_INFO *const pmi,
+                                 int bit_depth, int *min_bits);
+
+// Return the number of bits used to transmit each V palette color delta;
+// set *zero_count to the number of deltas that are 0.
+int av1_get_palette_delta_bits_v(const PALETTE_MODE_INFO *const pmi,
+                                 int bit_depth, int *zero_count, int *min_bits);
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+
+// Return the rate cost for transmitting luma palette color values.
+int av1_palette_color_cost_y(const PALETTE_MODE_INFO *const pmi, int bit_depth);
+
+// Return the rate cost for transmitting chroma palette color values.
+int av1_palette_color_cost_uv(const PALETTE_MODE_INFO *const pmi,
+                              int bit_depth);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 89161d9..07869a9 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2541,12 +2541,13 @@
       extend_palette_color_map(color_map, cols, rows, block_width,
                                block_height);
       palette_mode_cost =
-          dc_mode_cost + cpi->common.bit_depth * k * av1_cost_bit(128, 0) +
+          dc_mode_cost +
           cpi->palette_y_size_cost[bsize - BLOCK_8X8][k - PALETTE_MIN_SIZE] +
           write_uniform_cost(k, color_map[0]) +
           av1_cost_bit(
               av1_default_palette_y_mode_prob[bsize - BLOCK_8X8][palette_ctx],
               1);
+      palette_mode_cost += av1_palette_color_cost_y(pmi, cpi->common.bit_depth);
       for (i = 0; i < rows; ++i) {
         for (j = (i == 0 ? 1 : 0); j < cols; ++j) {
           int color_idx;
@@ -4585,6 +4586,22 @@
         centroids[i * 2 + 1] = lb_v + (2 * i + 1) * (ub_v - lb_v) / n / 2;
       }
       av1_k_means(data, centroids, color_map, rows * cols, n, 2, max_itr);
+#if CONFIG_PALETTE_DELTA_ENCODING
+      // Sort the U channel colors in ascending order.
+      for (i = 0; i < 2 * (n - 1); i += 2) {
+        int min_idx = i;
+        float min_val = centroids[i];
+        for (j = i + 2; j < 2 * n; j += 2)
+          if (centroids[j] < min_val) min_val = centroids[j], min_idx = j;
+        if (min_idx != i) {
+          float temp_u = centroids[i], temp_v = centroids[i + 1];
+          centroids[i] = centroids[min_idx];
+          centroids[i + 1] = centroids[min_idx + 1];
+          centroids[min_idx] = temp_u, centroids[min_idx + 1] = temp_v;
+        }
+      }
+      av1_calc_indices(data, centroids, color_map, rows * cols, n, 2);
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
       extend_palette_color_map(color_map, cols, rows, plane_block_width,
                                plane_block_height);
       pmi->palette_size[1] = n;
@@ -4605,12 +4622,11 @@
       if (tokenonly_rd_stats.rate == INT_MAX) continue;
       this_rate =
           tokenonly_rd_stats.rate + dc_mode_cost +
-          2 * cpi->common.bit_depth * n * av1_cost_bit(128, 0) +
           cpi->palette_uv_size_cost[bsize - BLOCK_8X8][n - PALETTE_MIN_SIZE] +
           write_uniform_cost(n, color_map[0]) +
           av1_cost_bit(
               av1_default_palette_uv_mode_prob[pmi->palette_size[0] > 0], 1);
-
+      this_rate += av1_palette_color_cost_uv(pmi, cpi->common.bit_depth);
       for (i = 0; i < rows; ++i) {
         for (j = (i == 0 ? 1 : 0); j < cols; ++j) {
           int color_idx;