Palette: use color cache to compress base colors

Build a list of the palette base colors used by the above and left
blocks, referred to as the "color cache". For each cache color,
signal whether it is present in the current block's palette, so that
the raw values of those colors need not be transmitted.

When palette-delta-encoding is enabled, this improves compression on
the screen_content test set by 2% on key frames and 1% overall.

Change-Id: I4cb027f1904aa9d0ab1c8f00ea9ee34bf5f16234
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index fb21bcf..52d95d7 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -1421,30 +1421,67 @@
 
 #if CONFIG_PALETTE
 #if CONFIG_PALETTE_DELTA_ENCODING
-// Write luma palette color values with delta encoding. Write the first value as
-// literal, and the deltas between each value and the previous one. The luma
-// palette is sorted so each delta is larger than 0.
-static void write_palette_colors_y(const PALETTE_MODE_INFO *const pmi,
-                                   int bit_depth, aom_writer *w) {
-  const int n = pmi->palette_size[0];
-  int min_bits, i;
-  int bits = av1_get_palette_delta_bits_y(pmi, bit_depth, &min_bits);
+// Transmit color values with delta encoding: the first value as a literal,
+// then a 2-bit width header and each delta minus "min_val" (the smallest
+// allowed delta, asserted), with the delta width shrinking as range narrows.
+static void delta_encode_palette_colors(const int *colors, int num,
+                                        int bit_depth, int min_val,
+                                        aom_writer *w) {
+  if (num <= 0) return;
+  aom_write_literal(w, colors[0], bit_depth);
+  if (num == 1) return;
+  int max_delta = 0;
+  int deltas[PALETTE_MAX_SIZE];  // Assumes num - 1 <= PALETTE_MAX_SIZE.
+  memset(deltas, 0, sizeof(deltas));
+  for (int i = 1; i < num; ++i) {
+    const int delta = colors[i] - colors[i - 1];
+    deltas[i - 1] = delta;
+    assert(delta >= min_val);
+    if (delta > max_delta) max_delta = delta;
+  }
+  const int min_bits = bit_depth - 3;  // Floor on the delta width.
+  int bits = AOMMAX(av1_ceil_log2(max_delta + 1 - min_val), min_bits);
+  int range = (1 << bit_depth) - colors[0] - min_val;
   aom_write_literal(w, bits - min_bits, 2);
-  aom_write_literal(w, pmi->palette_colors[0], bit_depth);
-  for (i = 1; i < n; ++i) {
-    aom_write_literal(
-        w, pmi->palette_colors[i] - pmi->palette_colors[i - 1] - 1, bits);
-    bits =
-        AOMMIN(bits, av1_ceil_log2((1 << bit_depth) - pmi->palette_colors[i]));
+  for (int i = 0; i < num - 1; ++i) {
+    aom_write_literal(w, deltas[i] - min_val, bits);
+    range -= deltas[i];  // Exclusive bound on the next (delta - min_val).
+    bits = AOMMIN(bits, av1_ceil_log2(range));
   }
 }
 
-// Write chroma palette color values. Use delta encoding for u channel as its
-// palette is sorted. For v channel, either use delta encoding or transmit
-// raw values directly, whichever costs less.
-static void write_palette_colors_uv(const PALETTE_MODE_INFO *const pmi,
+// Transmit luma palette colors: one flag per color-cache entry, stopping once
+// all n colors are found, then the colors not in the cache with delta
+// encoding (min_val 1, so those colors must be strictly increasing).
+static void write_palette_colors_y(const MACROBLOCKD *const xd,
+                                   const PALETTE_MODE_INFO *const pmi,
+                                   int bit_depth, aom_writer *w) {
+  const int n = pmi->palette_size[0];
+  const MODE_INFO *const above_mi = xd->above_mi;
+  const MODE_INFO *const left_mi = xd->left_mi;
+  uint16_t color_cache[2 * PALETTE_MAX_SIZE];
+  const int n_cache = av1_get_palette_cache(above_mi, left_mi, 0, color_cache);
+  int out_cache_colors[PALETTE_MAX_SIZE];
+  uint8_t cache_color_found[2 * PALETTE_MAX_SIZE];
+  const int n_out_cache =
+      av1_index_color_cache(color_cache, n_cache, pmi->palette_colors, n,
+                            cache_color_found, out_cache_colors);
+  int n_in_cache = 0;
+  for (int i = 0; i < n_cache && n_in_cache < n; ++i) {
+    const int found = cache_color_found[i];
+    aom_write_bit(w, found);
+    n_in_cache += found;
+  }
+  assert(n_in_cache + n_out_cache == n);
+  delta_encode_palette_colors(out_cache_colors, n_out_cache, bit_depth, 1, w);
+}
+
+// Write chroma palette colors. U colors: one flag per color-cache entry, then
+// the colors not in the cache with delta encoding (min delta 0). V colors: a
+// flag choosing delta encoding or raw values, whichever costs less.
+static void write_palette_colors_uv(const MACROBLOCKD *const xd,
+                                    const PALETTE_MODE_INFO *const pmi,
                                     int bit_depth, aom_writer *w) {
-  int i;
   const int n = pmi->palette_size[1];
 #if CONFIG_HIGHBITDEPTH
   const uint16_t *colors_u = pmi->palette_colors + PALETTE_MAX_SIZE;
@@ -1454,15 +1491,23 @@
   const uint8_t *colors_v = pmi->palette_colors + 2 * PALETTE_MAX_SIZE;
 #endif  // CONFIG_HIGHBITDEPTH
   // U channel colors.
-  int min_bits_u = 0;
-  int bits_u = av1_get_palette_delta_bits_u(pmi, bit_depth, &min_bits_u);
-  aom_write_literal(w, bits_u - min_bits_u, 2);
-  aom_write_literal(w, colors_u[0], bit_depth);
-  for (i = 1; i < n; ++i) {
-    aom_write_literal(w, colors_u[i] - colors_u[i - 1], bits_u);
-    bits_u = AOMMIN(bits_u, av1_ceil_log2(1 + (1 << bit_depth) - colors_u[i]));
+  const MODE_INFO *const above_mi = xd->above_mi;
+  const MODE_INFO *const left_mi = xd->left_mi;
+  uint16_t color_cache[2 * PALETTE_MAX_SIZE];
+  const int n_cache = av1_get_palette_cache(above_mi, left_mi, 1, color_cache);
+  int out_cache_colors[PALETTE_MAX_SIZE];
+  uint8_t cache_color_found[2 * PALETTE_MAX_SIZE];
+  const int n_out_cache = av1_index_color_cache(
+      color_cache, n_cache, colors_u, n, cache_color_found, out_cache_colors);
+  int n_in_cache = 0;
+  for (int i = 0; i < n_cache && n_in_cache < n; ++i) {
+    const int found = cache_color_found[i];
+    aom_write_bit(w, found);
+    n_in_cache += found;
   }
-  // V channel colors.
+  delta_encode_palette_colors(out_cache_colors, n_out_cache, bit_depth, 0, w);
+
+  // V channel colors. No color cache here: V colors are not sorted.
   const int max_val = 1 << bit_depth;
   int zero_count = 0, min_bits_v = 0;
   int bits_v =
@@ -1474,7 +1519,7 @@
     aom_write_bit(w, 1);
     aom_write_literal(w, bits_v - min_bits_v, 2);
     aom_write_literal(w, colors_v[0], bit_depth);
-    for (i = 1; i < n; ++i) {
+    for (int i = 1; i < n; ++i) {
       if (colors_v[i] == colors_v[i - 1]) {  // No need to signal sign bit.
         aom_write_literal(w, 0, bits_v);
         continue;
@@ -1491,7 +1536,7 @@
     }
   } else {  // Transmit raw values.
     aom_write_bit(w, 0);
-    for (i = 0; i < n; ++i) aom_write_literal(w, colors_v[i], bit_depth);
+    for (int i = 0; i < n; ++i) aom_write_literal(w, colors_v[i], bit_depth);
   }
 }
 #endif  // CONFIG_PALETTE_DELTA_ENCODING
@@ -1521,7 +1566,7 @@
                       av1_default_palette_y_size_prob[bsize - BLOCK_8X8],
                       &palette_size_encodings[n - PALETTE_MIN_SIZE]);
 #if CONFIG_PALETTE_DELTA_ENCODING
-      write_palette_colors_y(pmi, cm->bit_depth, w);
+      write_palette_colors_y(xd, pmi, cm->bit_depth, w);
 #else
       int i;
       for (i = 0; i < n; ++i)
@@ -1540,7 +1585,7 @@
                       av1_default_palette_uv_size_prob[bsize - BLOCK_8X8],
                       &palette_size_encodings[n - PALETTE_MIN_SIZE]);
 #if CONFIG_PALETTE_DELTA_ENCODING
-      write_palette_colors_uv(pmi, cm->bit_depth, w);
+      write_palette_colors_uv(xd, pmi, cm->bit_depth, w);
 #else
       int i;
       for (i = 0; i < n; ++i) {
diff --git a/av1/encoder/palette.c b/av1/encoder/palette.c
index 355141d..ca2a54f 100644
--- a/av1/encoder/palette.c
+++ b/av1/encoder/palette.c
@@ -167,31 +167,62 @@
 }
 
 #if CONFIG_PALETTE_DELTA_ENCODING
-int av1_get_palette_delta_bits_y(const PALETTE_MODE_INFO *const pmi,
-                                 int bit_depth, int *min_bits) {
-  const int n = pmi->palette_size[0];
-  int max_d = 0, i;
-  *min_bits = bit_depth - 3;
-  for (i = 1; i < n; ++i) {
-    const int delta = pmi->palette_colors[i] - pmi->palette_colors[i - 1];
-    assert(delta > 0);
-    if (delta > max_d) max_d = delta;
+static int delta_encode_cost(const int *colors, int num, int bit_depth,
+                             int min_val) {
+  if (num <= 0) return 0;  // Same bit layout as bitstream.c's writer.
+  int bits_cost = bit_depth;  // First color is sent as a literal.
+  if (num == 1) return bits_cost;
+  bits_cost += 2;  // 2-bit header for the delta width.
+  int max_delta = 0;
+  int deltas[PALETTE_MAX_SIZE];
+  const int min_bits = bit_depth - 3;  // Floor on the delta width.
+  for (int i = 1; i < num; ++i) {
+    const int delta = colors[i] - colors[i - 1];
+    deltas[i - 1] = delta;
+    assert(delta >= min_val);
+    if (delta > max_delta) max_delta = delta;
   }
-  return AOMMAX(av1_ceil_log2(max_d), *min_bits);
+  int bits_per_delta = AOMMAX(av1_ceil_log2(max_delta + 1 - min_val), min_bits);
+  int range = (1 << bit_depth) - colors[0] - min_val;
+  for (int i = 0; i < num - 1; ++i) {
+    bits_cost += bits_per_delta;
+    range -= deltas[i];  // Exclusive bound on the next (delta - min_val).
+    bits_per_delta = AOMMIN(bits_per_delta, av1_ceil_log2(range));
+  }
+  return bits_cost;
 }
 
-int av1_get_palette_delta_bits_u(const PALETTE_MODE_INFO *const pmi,
-                                 int bit_depth, int *min_bits) {
-  const int n = pmi->palette_size[1];
-  int max_d = 0, i;
-  *min_bits = bit_depth - 3;
-  for (i = 1; i < n; ++i) {
-    const int delta = pmi->palette_colors[PALETTE_MAX_SIZE + i] -
-                      pmi->palette_colors[PALETTE_MAX_SIZE + i - 1];
-    assert(delta >= 0);
-    if (delta > max_d) max_d = delta;
+int av1_index_color_cache(uint16_t *color_cache, int n_cache,
+                          const void *colors, int n_colors,
+                          uint8_t *cache_color_found, int *out_cache_colors) {
+#if CONFIG_HIGHBITDEPTH
+  const uint16_t *colors_in = (const uint16_t *)colors;
+#else
+  const uint8_t *colors_in = (const uint8_t *)colors;
+#endif  // CONFIG_HIGHBITDEPTH
+  if (n_cache <= 0) {  // No cache: every color is out of the cache.
+    for (int i = 0; i < n_colors; ++i) out_cache_colors[i] = colors_in[i];
+    return n_colors;
   }
-  return AOMMAX(av1_ceil_log2(max_d + 1), *min_bits);
+  memset(cache_color_found, 0, n_cache * sizeof(*cache_color_found));
+  int n_in_cache = 0;
+  int in_cache_flags[PALETTE_MAX_SIZE];  // Marks colors_in[j] found in cache.
+  memset(in_cache_flags, 0, sizeof(in_cache_flags));
+  for (int i = 0; i < n_cache && n_in_cache < n_colors; ++i) {
+    for (int j = 0; j < n_colors; ++j) {
+      if (colors_in[j] == color_cache[i]) {  // Assumes a duplicate-free cache.
+        in_cache_flags[j] = 1;
+        cache_color_found[i] = 1;
+        ++n_in_cache;
+        break;
+      }
+    }
+  }
+  int j = 0;  // Number of colors copied to out_cache_colors.
+  for (int i = 0; i < n_colors; ++i)
+    if (!in_cache_flags[i]) out_cache_colors[j++] = colors_in[i];
+  assert(j == n_colors - n_in_cache);
+  return j;
 }
 
 int av1_get_palette_delta_bits_v(const PALETTE_MODE_INFO *const pmi,
@@ -199,10 +230,10 @@
                                  int *min_bits) {
   const int n = pmi->palette_size[1];
   const int max_val = 1 << bit_depth;
-  int max_d = 0, i;
+  int max_d = 0;
   *min_bits = bit_depth - 4;
   *zero_count = 0;
-  for (i = 1; i < n; ++i) {
+  for (int i = 1; i < n; ++i) {
     const int delta = pmi->palette_colors[2 * PALETTE_MAX_SIZE + i] -
                       pmi->palette_colors[2 * PALETTE_MAX_SIZE + i - 1];
     const int v = abs(delta);
@@ -215,26 +246,42 @@
 #endif  // CONFIG_PALETTE_DELTA_ENCODING
 
 int av1_palette_color_cost_y(const PALETTE_MODE_INFO *const pmi,
+#if CONFIG_PALETTE_DELTA_ENCODING
+                             uint16_t *color_cache, int n_cache,
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
                              int bit_depth) {
   const int n = pmi->palette_size[0];
 #if CONFIG_PALETTE_DELTA_ENCODING
-  int min_bits = 0;
-  const int bits = av1_get_palette_delta_bits_y(pmi, bit_depth, &min_bits);
-  return av1_cost_bit(128, 0) * (2 + bit_depth + bits * (n - 1));
+  int out_cache_colors[PALETTE_MAX_SIZE];
+  uint8_t cache_color_found[2 * PALETTE_MAX_SIZE];
+  const int n_out_cache =
+      av1_index_color_cache(color_cache, n_cache, pmi->palette_colors, n,
+                            cache_color_found, out_cache_colors);
+  const int total_bits =  // n_cache: upper bound on cache flags sent.
+      n_cache + delta_encode_cost(out_cache_colors, n_out_cache, bit_depth, 1);
+  return total_bits * av1_cost_bit(128, 0);
 #else
   return bit_depth * n * av1_cost_bit(128, 0);
 #endif  // CONFIG_PALETTE_DELTA_ENCODING
 }
 
 int av1_palette_color_cost_uv(const PALETTE_MODE_INFO *const pmi,
+#if CONFIG_PALETTE_DELTA_ENCODING
+                              uint16_t *color_cache, int n_cache,
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
                               int bit_depth) {
   const int n = pmi->palette_size[1];
 #if CONFIG_PALETTE_DELTA_ENCODING
-  int cost = 0;
+  int total_bits = 0;
   // U channel palette color cost.
-  int min_bits_u = 0;
-  const int bits_u = av1_get_palette_delta_bits_u(pmi, bit_depth, &min_bits_u);
-  cost += av1_cost_bit(128, 0) * (2 + bit_depth + bits_u * (n - 1));
+  int out_cache_colors[PALETTE_MAX_SIZE];
+  uint8_t cache_color_found[2 * PALETTE_MAX_SIZE];
+  const int n_out_cache = av1_index_color_cache(
+      color_cache, n_cache, pmi->palette_colors + PALETTE_MAX_SIZE, n,
+      cache_color_found, out_cache_colors);
+  total_bits +=
+      n_cache + delta_encode_cost(out_cache_colors, n_out_cache, bit_depth, 0);
+
   // V channel palette color cost.
   int zero_count = 0, min_bits_v = 0;
   const int bits_v =
@@ -242,8 +289,8 @@
   const int bits_using_delta =
       2 + bit_depth + (bits_v + 1) * (n - 1) - zero_count;
   const int bits_using_raw = bit_depth * n;
-  cost += av1_cost_bit(128, 0) * (1 + AOMMIN(bits_using_delta, bits_using_raw));
-  return cost;
+  total_bits += 1 + AOMMIN(bits_using_delta, bits_using_raw);
+  return total_bits * av1_cost_bit(128, 0);
 #else
   return 2 * bit_depth * n * av1_cost_bit(128, 0);
 #endif  // CONFIG_PALETTE_DELTA_ENCODING
diff --git a/av1/encoder/palette.h b/av1/encoder/palette.h
index 5403ac5..7bd51c7 100644
--- a/av1/encoder/palette.h
+++ b/av1/encoder/palette.h
@@ -45,13 +45,12 @@
 #endif  // CONFIG_HIGHBITDEPTH
 
 #if CONFIG_PALETTE_DELTA_ENCODING
-// Return the number of bits used to transmit each luma palette color delta.
-int av1_get_palette_delta_bits_y(const PALETTE_MODE_INFO *const pmi,
-                                 int bit_depth, int *min_bits);
-
-// Return the number of bits used to transmit each U palette color delta.
-int av1_get_palette_delta_bits_u(const PALETTE_MODE_INFO *const pmi,
-                                 int bit_depth, int *min_bits);
+// Flags in "cache_color_found" which color-cache entries occur in "colors"
+// (uint16_t if CONFIG_HIGHBITDEPTH, else uint8_t), copies the colors not in
+// the cache to "out_cache_colors", and returns how many colors were copied.
+int av1_index_color_cache(uint16_t *color_cache, int n_cache,
+                          const void *colors, int n_colors,
+                          uint8_t *cache_color_found, int *out_cache_colors);
 
 // Return the number of bits used to transmit each v palette color delta;
 // assign zero_count with the number of deltas being 0.
@@ -60,10 +59,17 @@
 #endif  // CONFIG_PALETTE_DELTA_ENCODING
 
 // Return the rate cost for transmitting luma palette color values.
-int av1_palette_color_cost_y(const PALETTE_MODE_INFO *const pmi, int bit_depth);
+int av1_palette_color_cost_y(const PALETTE_MODE_INFO *const pmi,
+#if CONFIG_PALETTE_DELTA_ENCODING
+                             uint16_t *color_cache, int n_cache,
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+                             int bit_depth);
 
 // Return the rate cost for transmitting chroma palette color values.
 int av1_palette_color_cost_uv(const PALETTE_MODE_INFO *const pmi,
+#if CONFIG_PALETTE_DELTA_ENCODING
+                              uint16_t *color_cache, int n_cache,
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
                               int bit_depth);
 
 #ifdef __cplusplus
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 6422842..61c3404 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -2429,6 +2429,28 @@
   }
 }
 
+#if CONFIG_PALETTE_DELTA_ENCODING
+// Snap each centroid's first component to the nearest color-cache entry when
+// closer than 1.5 (biases toward cache hits). TODO(huisu): Try other schemes.
+static void optimize_palette_colors(uint16_t *color_cache, int n_cache,
+                                    int n_colors, int stride,
+                                    float *centroids) {
+  if (n_cache <= 0) return;
+  for (int i = 0; i < n_colors * stride; i += stride) {
+    float min_diff = fabsf(centroids[i] - color_cache[0]);
+    int idx = 0;
+    for (int j = 1; j < n_cache; ++j) {
+      float this_diff = fabsf(centroids[i] - color_cache[j]);
+      if (this_diff < min_diff) {  // Ties keep the earlier cache entry.
+        min_diff = this_diff;
+        idx = j;
+      }
+    }
+    if (min_diff < 1.5) centroids[i] = color_cache[idx];
+  }
+}
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+
 static int rd_pick_palette_intra_sby(const AV1_COMP *const cpi, MACROBLOCK *x,
                                      BLOCK_SIZE bsize, int palette_ctx,
                                      int dc_mode_cost, MB_MODE_INFO *best_mbmi,
@@ -2515,6 +2537,14 @@
 
     if (rows * cols > PALETTE_MAX_BLOCK_SIZE) return 0;
 
+#if CONFIG_PALETTE_DELTA_ENCODING
+    const MODE_INFO *above_mi = xd->above_mi;
+    const MODE_INFO *left_mi = xd->left_mi;
+    uint16_t color_cache[2 * PALETTE_MAX_SIZE];
+    const int n_cache =
+        av1_get_palette_cache(above_mi, left_mi, 0, color_cache);
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+
     for (n = colors > PALETTE_MAX_SIZE ? PALETTE_MAX_SIZE : colors; n >= 2;
          --n) {
       if (colors == PALETTE_MIN_SIZE) {
@@ -2529,6 +2559,9 @@
           centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2;
         }
         av1_k_means(data, centroids, color_map, rows * cols, n, 1, max_itr);
+#if CONFIG_PALETTE_DELTA_ENCODING
+        optimize_palette_colors(color_cache, n_cache, n, 1, centroids);
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
         k = av1_remove_duplicates(centroids, n);
         if (k < PALETTE_MIN_SIZE) {
           // Too few unique colors to create a palette. And DC_PRED will work
@@ -2558,7 +2591,11 @@
           av1_cost_bit(
               av1_default_palette_y_mode_prob[bsize - BLOCK_8X8][palette_ctx],
               1);
-      palette_mode_cost += av1_palette_color_cost_y(pmi, cpi->common.bit_depth);
+      palette_mode_cost += av1_palette_color_cost_y(pmi,
+#if CONFIG_PALETTE_DELTA_ENCODING
+                                                    color_cache, n_cache,
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+                                                    cpi->common.bit_depth);
       for (i = 0; i < rows; ++i) {
         for (j = (i == 0 ? 1 : 0); j < cols; ++j) {
           int color_idx;
@@ -4517,6 +4554,13 @@
   }
 #endif  // CONFIG_HIGHBITDEPTH
 
+#if CONFIG_PALETTE_DELTA_ENCODING
+  const MODE_INFO *above_mi = xd->above_mi;
+  const MODE_INFO *left_mi = xd->left_mi;
+  uint16_t color_cache[2 * PALETTE_MAX_SIZE];
+  const int n_cache = av1_get_palette_cache(above_mi, left_mi, 1, color_cache);
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+
   colors = colors_u > colors_v ? colors_u : colors_v;
   if (colors > 1 && colors <= 64) {
     int r, c, n, i, j;
@@ -4581,6 +4625,7 @@
       }
       av1_k_means(data, centroids, color_map, rows * cols, n, 2, max_itr);
 #if CONFIG_PALETTE_DELTA_ENCODING
+      optimize_palette_colors(color_cache, n_cache, n, 2, centroids);
       // Sort the U channel colors in ascending order.
       for (i = 0; i < 2 * (n - 1); i += 2) {
         int min_idx = i;
@@ -4620,7 +4665,11 @@
           write_uniform_cost(n, color_map[0]) +
           av1_cost_bit(
               av1_default_palette_uv_mode_prob[pmi->palette_size[0] > 0], 1);
-      this_rate += av1_palette_color_cost_uv(pmi, cpi->common.bit_depth);
+      this_rate += av1_palette_color_cost_uv(pmi,
+#if CONFIG_PALETTE_DELTA_ENCODING
+                                             color_cache, n_cache,
+#endif  // CONFIG_PALETTE_DELTA_ENCODING
+                                             cpi->common.bit_depth);
       for (i = 0; i < rows; ++i) {
         for (j = (i == 0 ? 1 : 0); j < cols; ++j) {
           int color_idx;