Add cdf optimizer for a few symbols with variable # of modes

Change-Id: Ifa561062fa49d1a537258a93046f63292cda4058
diff --git a/av1/common/entropymode.h b/av1/common/entropymode.h
index 110531c..950898f 100644
--- a/av1/common/entropymode.h
+++ b/av1/common/entropymode.h
@@ -255,7 +255,7 @@
 #endif  // CONFIG_ENTROPY_STATS
   unsigned int filter_intra_mode[FILTER_INTRA_MODES];
   unsigned int filter_intra[BLOCK_SIZES_ALL][2];
-  unsigned int switchable_restore[3];
+  unsigned int switchable_restore[RESTORE_SWITCHABLE_TYPES];
   unsigned int wiener_restore[2];
   unsigned int sgrproj_restore[2];
 } FRAME_COUNTS;
diff --git a/tools/aom_entropy_optimizer.c b/tools/aom_entropy_optimizer.c
index 933ccc7..09d94eb 100644
--- a/tools/aom_entropy_optimizer.c
+++ b/tools/aom_entropy_optimizer.c
@@ -43,6 +43,9 @@
   csum[0] = counts[0] + 1;
   for (int i = 1; i < modes; ++i) csum[i] = counts[i] + 1 + csum[i - 1];
 
+  for (int i = 0; i < modes; ++i) fprintf(logfile, "%d ", counts[i]);
+  fprintf(logfile, "\n");
+
   int64_t sum = csum[modes - 1];
   const int64_t round_shift = sum >> 1;
   for (int i = 0; i < modes; ++i) {
@@ -69,18 +72,18 @@
     (*ct_ptr) += total_modes;
 
     if (tabs > 0) fprintf(probsfile, "%*c", tabs * SPACES_PER_TAB, ' ');
-    fprintf(probsfile, "AOM_CDF%d( ", total_modes);
+    fprintf(probsfile, "AOM_CDF%d(", total_modes);
     for (int k = 0; k < total_modes - 1; ++k) {
       fprintf(probsfile, "%d", cdfs[k]);
-      if (k < total_modes - 2) fprintf(probsfile, ",");
+      if (k < total_modes - 2) fprintf(probsfile, ", ");
     }
-    fprintf(probsfile, " )");
+    fprintf(probsfile, ")");
   } else {
     for (int k = 0; k < total_modes; ++k) {
       int tabs_next_level;
 
       if (dim_of_cts == 2)
-        fprintf(probsfile, "%*c{", tabs * SPACES_PER_TAB, ' ');
+        fprintf(probsfile, "%*c{ ", tabs * SPACES_PER_TAB, ' ');
       else
         fprintf(probsfile, "%*c{\n", tabs * SPACES_PER_TAB, ' ');
       tabs_next_level = dim_of_cts == 2 ? 0 : tabs + 1;
@@ -92,9 +95,9 @@
 
       if (dim_of_cts == 2) {
         if (k == total_modes - 1)
-          fprintf(probsfile, "}\n");
+          fprintf(probsfile, " }\n");
         else
-          fprintf(probsfile, "},\n");
+          fprintf(probsfile, " },\n");
       } else {
         if (k == total_modes - 1)
           fprintf(probsfile, "%*c}\n", tabs * SPACES_PER_TAB, ' ');
@@ -112,11 +115,13 @@
   aom_count_type *ct_ptr = counts;
 
   fprintf(probsfile, "%s = {\n", prefix);
+  fprintf(logfile, "%s\n", prefix);
   if (parse_counts_for_cdf_opt(&ct_ptr, probsfile, 1, dim_of_cts,
                                cts_each_dim)) {
     fprintf(probsfile, "Optimizer failed!\n");
   }
   fprintf(probsfile, "};\n\n");
+  fprintf(logfile, "============================\n");
 }
 
 static void optimize_uv_mode(aom_count_type *counts, FILE *const probsfile,
@@ -125,15 +130,16 @@
 
   fprintf(probsfile, "%s = {\n", prefix);
   fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' ');
+  fprintf(logfile, "%s\n", prefix);
   cts_each_dim[2] = UV_INTRA_MODES - 1;
   for (int k = 0; k < cts_each_dim[1]; ++k) {
-    fprintf(probsfile, "%*c{", 2 * SPACES_PER_TAB, ' ');
+    fprintf(probsfile, "%*c{ ", 2 * SPACES_PER_TAB, ' ');
     parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, dim_of_cts - 2,
                              cts_each_dim + 2);
     if (k + 1 == cts_each_dim[1]) {
-      fprintf(probsfile, "}\n");
+      fprintf(probsfile, " }\n");
     } else {
-      fprintf(probsfile, "},\n");
+      fprintf(probsfile, " },\n");
     }
     ++ct_ptr;
   }
@@ -144,6 +150,109 @@
                            cts_each_dim + 1);
   fprintf(probsfile, "%*c}\n", SPACES_PER_TAB, ' ');
   fprintf(probsfile, "};\n\n");
+  fprintf(logfile, "============================\n");
+}
+
+static void optimize_cdf_table_var_modes_2d(aom_count_type *counts,
+                                            FILE *const probsfile,
+                                            int dim_of_cts, int *cts_each_dim,
+                                            int *modes_each_ctx, char *prefix) {
+  aom_count_type *ct_ptr = counts;
+
+  assert(dim_of_cts == 2);
+  (void)dim_of_cts;
+
+  fprintf(probsfile, "%s = {\n", prefix);
+  fprintf(logfile, "%s\n", prefix);
+
+  for (int d0_idx = 0; d0_idx < cts_each_dim[0]; ++d0_idx) {
+    int num_of_modes = modes_each_ctx[d0_idx];
+
+    if (num_of_modes > 0) {
+      fprintf(probsfile, "%*c{ ", SPACES_PER_TAB, ' ');
+      parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, 1, &num_of_modes);
+      ct_ptr += cts_each_dim[1] - num_of_modes;
+      fprintf(probsfile, " },\n");
+    } else {
+      fprintf(probsfile, "%*c{ 0 },\n", SPACES_PER_TAB, ' ');
+      fprintf(logfile, "dummy cdf, no need to optimize\n");
+      ct_ptr += cts_each_dim[1];
+    }
+  }
+  fprintf(probsfile, "};\n\n");
+  fprintf(logfile, "============================\n");
+}
+
+static void optimize_cdf_table_var_modes_3d(aom_count_type *counts,
+                                            FILE *const probsfile,
+                                            int dim_of_cts, int *cts_each_dim,
+                                            int *modes_each_ctx, char *prefix) {
+  aom_count_type *ct_ptr = counts;
+
+  assert(dim_of_cts == 3);
+  (void)dim_of_cts;
+
+  fprintf(probsfile, "%s = {\n", prefix);
+  fprintf(logfile, "%s\n", prefix);
+
+  for (int d0_idx = 0; d0_idx < cts_each_dim[0]; ++d0_idx) {
+    fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' ');
+    for (int d1_idx = 0; d1_idx < cts_each_dim[1]; ++d1_idx) {
+      int num_of_modes = modes_each_ctx[d0_idx];
+
+      if (num_of_modes > 0) {
+        fprintf(probsfile, "%*c{ ", 2 * SPACES_PER_TAB, ' ');
+        parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, 1, &num_of_modes);
+        ct_ptr += cts_each_dim[2] - num_of_modes;
+        fprintf(probsfile, " },\n");
+      } else {
+        fprintf(probsfile, "%*c{ 0 },\n", 2 * SPACES_PER_TAB, ' ');
+        fprintf(logfile, "dummy cdf, no need to optimize\n");
+        ct_ptr += cts_each_dim[2];
+      }
+    }
+    fprintf(probsfile, "%*c},\n", SPACES_PER_TAB, ' ');
+  }
+  fprintf(probsfile, "};\n\n");
+  fprintf(logfile, "============================\n");
+}
+
+static void optimize_cdf_table_var_modes_4d(aom_count_type *counts,
+                                            FILE *const probsfile,
+                                            int dim_of_cts, int *cts_each_dim,
+                                            int *modes_each_ctx, char *prefix) {
+  aom_count_type *ct_ptr = counts;
+
+  assert(dim_of_cts == 4);
+  (void)dim_of_cts;
+
+  fprintf(probsfile, "%s = {\n", prefix);
+  fprintf(logfile, "%s\n", prefix);
+
+  for (int d0_idx = 0; d0_idx < cts_each_dim[0]; ++d0_idx) {
+    fprintf(probsfile, "%*c{\n", SPACES_PER_TAB, ' ');
+    for (int d1_idx = 0; d1_idx < cts_each_dim[1]; ++d1_idx) {
+      fprintf(probsfile, "%*c{\n", 2 * SPACES_PER_TAB, ' ');
+      for (int d2_idx = 0; d2_idx < cts_each_dim[2]; ++d2_idx) {
+        int num_of_modes = modes_each_ctx[d0_idx];
+
+        if (num_of_modes > 0) {
+          fprintf(probsfile, "%*c{ ", 3 * SPACES_PER_TAB, ' ');
+          parse_counts_for_cdf_opt(&ct_ptr, probsfile, 0, 1, &num_of_modes);
+          ct_ptr += cts_each_dim[3] - num_of_modes;
+          fprintf(probsfile, " },\n");
+        } else {
+          fprintf(probsfile, "%*c{ 0 },\n", 3 * SPACES_PER_TAB, ' ');
+          fprintf(logfile, "dummy cdf, no need to optimize\n");
+          ct_ptr += cts_each_dim[3];
+        }
+      }
+      fprintf(probsfile, "%*c},\n", 2 * SPACES_PER_TAB, ' ');
+    }
+    fprintf(probsfile, "%*c},\n", SPACES_PER_TAB, ' ');
+  }
+  fprintf(probsfile, "};\n\n");
+  fprintf(logfile, "============================\n");
 }
 
 int main(int argc, const char **argv) {
@@ -189,8 +298,7 @@
   cts_each_dim[0] = DIRECTIONAL_MODES;
   cts_each_dim[1] = 2 * MAX_ANGLE_DELTA + 1;
   optimize_cdf_table(&fc.angle_delta[0][0], probsfile, 2, cts_each_dim,
-                     "const aom_cdf_prob\n"
-                     "default_angle_delta_cdf"
+                     "static const aom_cdf_prob default_angle_delta_cdf"
                      "[DIRECTIONAL_MODES][CDF_SIZE(2 * MAX_ANGLE_DELTA + 1)]");
 
   /* Intra mode (non-keyframe luma) */
@@ -210,6 +318,39 @@
                    "default_uv_mode_cdf[CFL_ALLOWED_TYPES][INTRA_MODES]"
                    "[CDF_SIZE(UV_INTRA_MODES)]");
 
+  /* block partition */
+  cts_each_dim[0] = PARTITION_CONTEXTS;
+  cts_each_dim[1] = EXT_PARTITION_TYPES;
+  int part_types_each_ctx[PARTITION_CONTEXTS] = {
+    4, 4, 4, 4, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 8, 8, 8, 8
+  };
+  optimize_cdf_table_var_modes_2d(
+      &fc.partition[0][0], probsfile, 2, cts_each_dim, part_types_each_ctx,
+      "static const aom_cdf_prob default_partition_cdf[PARTITION_CONTEXTS]"
+      "[CDF_SIZE(EXT_PARTITION_TYPES)]");
+
+  /* tx type */
+  cts_each_dim[0] = EXT_TX_SETS_INTRA;
+  cts_each_dim[1] = EXT_TX_SIZES;
+  cts_each_dim[2] = INTRA_MODES;
+  cts_each_dim[3] = TX_TYPES;
+  int intra_ext_tx_types_each_ctx[EXT_TX_SETS_INTRA] = { 0, 7, 5 };
+  optimize_cdf_table_var_modes_4d(
+      &fc.intra_ext_tx[0][0][0][0], probsfile, 4, cts_each_dim,
+      intra_ext_tx_types_each_ctx,
+      "static const aom_cdf_prob default_intra_ext_tx_cdf[EXT_TX_SETS_INTRA]"
+      "[EXT_TX_SIZES][INTRA_MODES][CDF_SIZE(TX_TYPES)]");
+
+  cts_each_dim[0] = EXT_TX_SETS_INTER;
+  cts_each_dim[1] = EXT_TX_SIZES;
+  cts_each_dim[2] = TX_TYPES;
+  int inter_ext_tx_types_each_ctx[EXT_TX_SETS_INTER] = { 0, 16, 12, 2 };
+  optimize_cdf_table_var_modes_3d(
+      &fc.inter_ext_tx[0][0][0], probsfile, 3, cts_each_dim,
+      inter_ext_tx_types_each_ctx,
+      "static const aom_cdf_prob default_inter_ext_tx_cdf[EXT_TX_SETS_INTER]"
+      "[EXT_TX_SIZES][CDF_SIZE(TX_TYPES)]");
+
   /* Chroma from Luma */
   cts_each_dim[0] = CFL_JOINT_SIGNS;
   optimize_cdf_table(&fc.cfl_sign[0], probsfile, 1, cts_each_dim,
@@ -222,14 +363,6 @@
                      "default_cfl_alpha_cdf[CFL_ALPHA_CONTEXTS]"
                      "[CDF_SIZE(CFL_ALPHABET_SIZE)]");
 
-  /* Partition */
-  cts_each_dim[0] = PARTITION_CONTEXTS;
-  cts_each_dim[1] = EXT_PARTITION_TYPES;
-  optimize_cdf_table(&fc.partition[0][0], probsfile, 2, cts_each_dim,
-                     "static const aom_cdf_prob\n"
-                     "default_partition_cdf[PARTITION_CONTEXTS][CDF_SIZE(EXT_"
-                     "PARTITION_TYPES)]");
-
   /* Interpolation filter */
   cts_each_dim[0] = SWITCHABLE_FILTER_CONTEXTS;
   cts_each_dim[1] = SWITCHABLE_FILTERS;
@@ -377,6 +510,52 @@
       "static const aom_cdf_prob\n"
       "default_comp_bwdref_cdf[REF_CONTEXTS][BWD_REFS - 1][CDF_SIZE(2)]");
 
+  /* palette */
+  cts_each_dim[0] = PALATTE_BSIZE_CTXS;
+  cts_each_dim[1] = PALETTE_SIZES;
+  optimize_cdf_table(&fc.palette_y_size[0][0], probsfile, 2, cts_each_dim,
+                     "const aom_cdf_prob default_palette_y_size_cdf"
+                     "[PALATTE_BSIZE_CTXS][CDF_SIZE(PALETTE_SIZES)]");
+
+  cts_each_dim[0] = PALATTE_BSIZE_CTXS;
+  cts_each_dim[1] = PALETTE_SIZES;
+  optimize_cdf_table(&fc.palette_uv_size[0][0], probsfile, 2, cts_each_dim,
+                     "const aom_cdf_prob default_palette_uv_size_cdf"
+                     "[PALATTE_BSIZE_CTXS][CDF_SIZE(PALETTE_SIZES)]");
+
+  cts_each_dim[0] = PALATTE_BSIZE_CTXS;
+  cts_each_dim[1] = PALETTE_Y_MODE_CONTEXTS;
+  cts_each_dim[2] = 2;
+  optimize_cdf_table(&fc.palette_y_mode[0][0][0], probsfile, 3, cts_each_dim,
+                     "const aom_cdf_prob default_palette_y_mode_cdf"
+                     "[PALATTE_BSIZE_CTXS][PALETTE_Y_MODE_CONTEXTS]"
+                     "[CDF_SIZE(2)]");
+
+  cts_each_dim[0] = PALETTE_UV_MODE_CONTEXTS;
+  cts_each_dim[1] = 2;
+  optimize_cdf_table(&fc.palette_uv_mode[0][0], probsfile, 2, cts_each_dim,
+                     "const aom_cdf_prob default_palette_uv_mode_cdf"
+                     "[PALETTE_UV_MODE_CONTEXTS][CDF_SIZE(2)]");
+
+  cts_each_dim[0] = PALETTE_SIZES;
+  cts_each_dim[1] = PALETTE_COLOR_INDEX_CONTEXTS;
+  cts_each_dim[2] = PALETTE_COLORS;
+  int palette_color_indexes_each_ctx[PALETTE_SIZES] = { 2, 3, 4, 5, 6, 7, 8 };
+  optimize_cdf_table_var_modes_3d(
+      &fc.palette_y_color_index[0][0][0], probsfile, 3, cts_each_dim,
+      palette_color_indexes_each_ctx,
+      "const aom_cdf_prob default_palette_y_color_index_cdf[PALETTE_SIZES]"
+      "[PALETTE_COLOR_INDEX_CONTEXTS][CDF_SIZE(PALETTE_COLORS)]");
+
+  cts_each_dim[0] = PALETTE_SIZES;
+  cts_each_dim[1] = PALETTE_COLOR_INDEX_CONTEXTS;
+  cts_each_dim[2] = PALETTE_COLORS;
+  optimize_cdf_table_var_modes_3d(
+      &fc.palette_uv_color_index[0][0][0], probsfile, 3, cts_each_dim,
+      palette_color_indexes_each_ctx,
+      "const aom_cdf_prob default_palette_uv_color_index_cdf[PALETTE_SIZES]"
+      "[PALETTE_COLOR_INDEX_CONTEXTS][CDF_SIZE(PALETTE_COLORS)]");
+
   /* Transform size */
   cts_each_dim[0] = TXFM_PARTITION_CONTEXTS;
   cts_each_dim[1] = 2;
@@ -392,6 +571,26 @@
                      "static const aom_cdf_prob "
                      "default_skip_cdfs[SKIP_CONTEXTS][CDF_SIZE(2)]");
 
+  /* Skip mode flag */
+  cts_each_dim[0] = SKIP_MODE_CONTEXTS;
+  cts_each_dim[1] = 2;
+  optimize_cdf_table(&fc.skip_mode[0][0], probsfile, 2, cts_each_dim,
+                     "static const aom_cdf_prob "
+                     "default_skip_mode_cdfs[SKIP_MODE_CONTEXTS][CDF_SIZE(2)]");
+
+  /* joint compound flag */
+  cts_each_dim[0] = COMP_INDEX_CONTEXTS;
+  cts_each_dim[1] = 2;
+  optimize_cdf_table(&fc.compound_index[0][0], probsfile, 2, cts_each_dim,
+                     "static const aom_cdf_prob default_compound_idx_cdfs"
+                     "[COMP_INDEX_CONTEXTS][CDF_SIZE(2)]");
+
+  cts_each_dim[0] = COMP_GROUP_IDX_CONTEXTS;
+  cts_each_dim[1] = 2;
+  optimize_cdf_table(&fc.comp_group_idx[0][0], probsfile, 2, cts_each_dim,
+                     "static const aom_cdf_prob default_comp_group_idx_cdfs"
+                     "[COMP_GROUP_IDX_CONTEXTS][CDF_SIZE(2)]");
+
   /* intrabc */
   cts_each_dim[0] = 2;
   optimize_cdf_table(
@@ -411,6 +610,33 @@
                      "static const aom_cdf_prob "
                      "default_filter_intra_cdfs[BLOCK_SIZES_ALL][CDF_SIZE(2)]");
 
+  /* restoration type */
+  cts_each_dim[0] = RESTORE_SWITCHABLE_TYPES;
+  optimize_cdf_table(&fc.switchable_restore[0], probsfile, 1, cts_each_dim,
+                     "static const aom_cdf_prob default_switchable_restore_cdf"
+                     "[CDF_SIZE(RESTORE_SWITCHABLE_TYPES)]");
+
+  cts_each_dim[0] = 2;
+  optimize_cdf_table(&fc.wiener_restore[0], probsfile, 1, cts_each_dim,
+                     "static const aom_cdf_prob default_wiener_restore_cdf"
+                     "[CDF_SIZE(2)]");
+
+  cts_each_dim[0] = 2;
+  optimize_cdf_table(&fc.sgrproj_restore[0], probsfile, 1, cts_each_dim,
+                     "static const aom_cdf_prob default_sgrproj_restore_cdf"
+                     "[CDF_SIZE(2)]");
+
+  /* intra tx size */
+  cts_each_dim[0] = MAX_TX_CATS;
+  cts_each_dim[1] = TX_SIZE_CONTEXTS;
+  cts_each_dim[2] = MAX_TX_DEPTH + 1;
+  int intra_tx_sizes_each_ctx[MAX_TX_CATS] = { 2, 3, 3, 3 };
+  optimize_cdf_table_var_modes_3d(
+      &fc.intra_tx_size[0][0][0], probsfile, 3, cts_each_dim,
+      intra_tx_sizes_each_ctx,
+      "static const aom_cdf_prob default_tx_size_cdf"
+      "[MAX_TX_CATS][TX_SIZE_CONTEXTS][CDF_SIZE(MAX_TX_DEPTH + 1)]");
+
   /* transform coding */
   cts_each_dim[0] = TOKEN_CDF_Q_CTXS;
   cts_each_dim[1] = TX_SIZES;
@@ -522,13 +748,6 @@
       "[TOKEN_CDF_Q_CTXS][TX_SIZES][PLANE_TYPES][SIG_COEF_CONTEXTS_EOB]"
       "[CDF_SIZE(NUM_BASE_LEVELS + 1)]");
 
-  /* Skip mode flag */
-  cts_each_dim[0] = SKIP_MODE_CONTEXTS;
-  cts_each_dim[1] = 2;
-  optimize_cdf_table(&fc.skip_mode[0][0], probsfile, 2, cts_each_dim,
-                     "static const aom_cdf_prob "
-                     "default_skip_mode_cdfs[SKIP_MODE_CONTEXTS][CDF_SIZE(2)]");
-
   fclose(statsfile);
   fclose(logfile);
   fclose(probsfile);