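Add an encoder control to set the ANS window size

Add AV1E_SET_ANS_WINDOW_SIZE_LOG2 (range 8..23, default 23). The value is
signalled in key-frame and intra-only frame headers as a 4-bit field
holding log2(window size) - 8, and the decoder uses it in place of the
fixed ANS_MAX_SYMBOLS window.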

Change-Id: I3d64ec4bbc72143b30a094ece7a6c711d6b479cd
diff --git a/av1/av1_cx_iface.c b/av1/av1_cx_iface.c
index 1286dda..48f3a6d 100644
--- a/av1/av1_cx_iface.c
+++ b/av1/av1_cx_iface.c
@@ -66,6 +66,9 @@
   int render_width;
   int render_height;
   aom_superblock_size_t superblock_size;
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+  int ans_window_size_log2;
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
 };
 
 static struct av1_extracfg default_extra_cfg = {
@@ -106,16 +109,19 @@
   1,  // max number of tile groups
   0,  // mtu_size
 #endif
-  1,                           // frame_parallel_decoding_mode
-  NO_AQ,                       // aq_mode
-  0,                           // frame_periodic_delta_q
-  AOM_BITS_8,                  // Bit depth
-  AOM_CONTENT_DEFAULT,         // content
-  AOM_CS_UNKNOWN,              // color space
-  0,                           // color range
-  0,                           // render width
-  0,                           // render height
-  AOM_SUPERBLOCK_SIZE_DYNAMIC  // superblock_size
+  1,                            // frame_parallel_decoding_mode
+  NO_AQ,                        // aq_mode
+  0,                            // frame_periodic_delta_q
+  AOM_BITS_8,                   // Bit depth
+  AOM_CONTENT_DEFAULT,          // content
+  AOM_CS_UNKNOWN,               // color space
+  0,                            // color range
+  0,                            // render width
+  0,                            // render height
+  AOM_SUPERBLOCK_SIZE_DYNAMIC,  // superblock_size
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+  23,  // ans_window_size_log2
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
 };
 
 struct aom_codec_alg_priv {
@@ -310,6 +316,9 @@
   }
   RANGE_CHECK(extra_cfg, color_space, AOM_CS_UNKNOWN, AOM_CS_SRGB);
   RANGE_CHECK(extra_cfg, color_range, 0, 1);
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+  RANGE_CHECK(extra_cfg, ans_window_size_log2, 8, 23);  // signalled in 4 bits
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
   return AOM_CODEC_OK;
 }
 
@@ -475,6 +484,9 @@
 #if CONFIG_EXT_PARTITION
   oxcf->superblock_size = extra_cfg->superblock_size;
 #endif  // CONFIG_EXT_PARTITION
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+  oxcf->ans_window_size_log2 = extra_cfg->ans_window_size_log2;
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
 
 #if CONFIG_EXT_TILE
   {
@@ -1334,6 +1346,15 @@
   return update_extra_cfg(ctx, &extra_cfg);
 }
 
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+static aom_codec_err_t ctrl_set_ans_window_size_log2(aom_codec_alg_priv_t *ctx,
+                                                     va_list args) {
+  struct av1_extracfg extra_cfg = ctx->extra_cfg;
+  extra_cfg.ans_window_size_log2 = CAST(AV1E_SET_ANS_WINDOW_SIZE_LOG2, args);
+  return update_extra_cfg(ctx, &extra_cfg);
+}
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
+
 static aom_codec_ctrl_fn_map_t encoder_ctrl_maps[] = {
   { AOM_COPY_REFERENCE, ctrl_copy_reference },
   { AOME_USE_REFERENCE, ctrl_use_reference },
@@ -1384,6 +1405,9 @@
   { AV1E_SET_MAX_GF_INTERVAL, ctrl_set_max_gf_interval },
   { AV1E_SET_RENDER_SIZE, ctrl_set_render_size },
   { AV1E_SET_SUPERBLOCK_SIZE, ctrl_set_superblock_size },
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+  { AV1E_SET_ANS_WINDOW_SIZE_LOG2, ctrl_set_ans_window_size_log2 },
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
 
   // Getters
   { AOME_GET_LAST_QUANTIZER, ctrl_get_quantizer },
diff --git a/av1/common/onyxc_int.h b/av1/common/onyxc_int.h
index 4315431..ca93bef 100644
--- a/av1/common/onyxc_int.h
+++ b/av1/common/onyxc_int.h
@@ -16,6 +16,9 @@
 #include "./av1_rtcd.h"
 #include "aom/internal/aom_codec_internal.h"
 #include "aom_util/aom_thread.h"
+#if CONFIG_ANS
+#include "aom_dsp/ans.h"
+#endif  // CONFIG_ANS
 #include "av1/common/alloccommon.h"
 #include "av1/common/entropy.h"
 #include "av1/common/entropymode.h"
@@ -422,6 +425,9 @@
   int refresh_mask;
   int invalid_delta_frame_id_minus1;
 #endif
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+  int ans_window_size_log2;  // ANS reader window is 1 << this
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
 } AV1_COMMON;
 
 #if CONFIG_REFERENCE_BUFFER
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index b1d09ff..6788302 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -2348,8 +2348,11 @@
 static void setup_bool_decoder(const uint8_t *data, const uint8_t *data_end,
                                const size_t read_size,
                                struct aom_internal_error_info *error_info,
-                               aom_reader *r, aom_decrypt_cb decrypt_cb,
-                               void *decrypt_state) {
+                               aom_reader *r,
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+                               int window_size,
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
+                               aom_decrypt_cb decrypt_cb, void *decrypt_state) {
   // Validate the calculated partition length. If the buffer
   // described by the partition can't be fully read, then restrict
   // it to the portion that can be (for EC mode) or throw an error.
@@ -2358,7 +2361,7 @@
                        "Truncated packet or corrupt tile length");
 
 #if CONFIG_ANS && ANS_MAX_SYMBOLS
-  r->window_size = ANS_MAX_SYMBOLS;
+  r->window_size = window_size;
 #endif
   if (aom_reader_init(r, data, read_size, decrypt_cb, decrypt_state))
     aom_internal_error(error_info, AOM_CODEC_MEM_ERROR,
@@ -3393,7 +3396,11 @@
 #endif
       av1_tile_init(&td->xd.tile, td->cm, tile_row, tile_col);
       setup_bool_decoder(buf->data, data_end, buf->size, &cm->error,
-                         &td->bit_reader, pbi->decrypt_cb, pbi->decrypt_state);
+                         &td->bit_reader,
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+                         1 << cm->ans_window_size_log2,
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
+                         pbi->decrypt_cb, pbi->decrypt_state);
 #if CONFIG_ACCOUNTING
       if (pbi->acct_enabled) {
         td->bit_reader.accounting = &pbi->accounting;
@@ -3746,8 +3753,11 @@
         av1_tile_init(tile_info, cm, tile_row, buf->col);
         av1_tile_init(&twd->xd.tile, cm, tile_row, buf->col);
         setup_bool_decoder(buf->data, data_end, buf->size, &cm->error,
-                           &twd->bit_reader, pbi->decrypt_cb,
-                           pbi->decrypt_state);
+                           &twd->bit_reader,
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+                           1 << cm->ans_window_size_log2,
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
+                           pbi->decrypt_cb, pbi->decrypt_state);
         av1_init_macroblockd(cm, &twd->xd,
 #if CONFIG_PVQ
                              twd->pvq_ref_coeff,
@@ -4010,6 +4020,9 @@
       memset(&cm->ref_frame_map, -1, sizeof(cm->ref_frame_map));
       pbi->need_resync = 0;
     }
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+    cm->ans_window_size_log2 = aom_rb_read_literal(rb, 4) + 8;  // in [8, 23]
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
 #if CONFIG_PALETTE
     cm->allow_screen_content_tools = aom_rb_read_bit(rb);
 #endif  // CONFIG_PALETTE
@@ -4049,6 +4062,9 @@
         memset(&cm->ref_frame_map, -1, sizeof(cm->ref_frame_map));
         pbi->need_resync = 0;
       }
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+      cm->ans_window_size_log2 = aom_rb_read_literal(rb, 4) + 8;  // in [8, 23]
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
     } else if (pbi->need_resync != 1) { /* Skip if need resync */
       pbi->refresh_frame_flags = aom_rb_read_literal(rb, REF_FRAMES);
 
@@ -4380,7 +4396,7 @@
 #endif
 
 #if CONFIG_ANS && ANS_MAX_SYMBOLS
-  r.window_size = ANS_MAX_SYMBOLS;
+  r.window_size = 1 << cm->ans_window_size_log2;
 #endif
   if (aom_reader_init(&r, data, partition_size, pbi->decrypt_cb,
                       pbi->decrypt_state))
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index 2746d7b..47df82a 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -4124,6 +4124,11 @@
     write_sync_code(wb);
     write_bitdepth_colorspace_sampling(cm, wb);
     write_frame_size(cm, wb);
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+    assert(cm->ans_window_size_log2 >= 8);
+    assert(cm->ans_window_size_log2 < 24);
+    aom_wb_write_literal(wb, cm->ans_window_size_log2 - 8, 4);  // log2 - 8
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
 #if CONFIG_PALETTE
     aom_wb_write_bit(wb, cm->allow_screen_content_tools);
 #endif  // CONFIG_PALETTE
@@ -4159,6 +4164,12 @@
       aom_wb_write_literal(wb, get_refresh_mask(cpi), REF_FRAMES);
 #endif  // CONFIG_EXT_REFS
       write_frame_size(cm, wb);
+
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+      assert(cm->ans_window_size_log2 >= 8);
+      assert(cm->ans_window_size_log2 < 24);
+      aom_wb_write_literal(wb, cm->ans_window_size_log2 - 8, 4);  // log2 - 8
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
     } else {
       MV_REFERENCE_FRAME ref_frame;
 
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index f5f5dc6..c545eab 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -799,9 +799,8 @@
     unsigned int tokens = get_token_alloc(cm->mb_rows, cm->mb_cols);
     CHECK_MEM_ERROR(cm, cpi->tile_tok[0][0],
                     aom_calloc(tokens, sizeof(*cpi->tile_tok[0][0])));
-#if CONFIG_ANS
-    aom_buf_ans_alloc(&cpi->buf_ans, &cm->error,
-                      ANS_MAX_SYMBOLS ? ANS_MAX_SYMBOLS : tokens);
-#endif  // CONFIG_ANS
+#if CONFIG_ANS && !ANS_MAX_SYMBOLS
+    aom_buf_ans_alloc(&cpi->buf_ans, &cm->error, (int)tokens);
+#endif  // CONFIG_ANS && !ANS_MAX_SYMBOLS
   }
 
@@ -2034,6 +2033,15 @@
 #if CONFIG_AOM_HIGHBITDEPTH
   highbd_set_var_fns(cpi);
 #endif
+  // TODO(review): mid-GOP ANS window changes may desync the decoder.
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+  cpi->common.ans_window_size_log2 = cpi->oxcf.ans_window_size_log2;
+  if (cpi->buf_ans.size != (1 << cpi->common.ans_window_size_log2)) {
+    aom_buf_ans_free(&cpi->buf_ans);
+    aom_buf_ans_alloc(&cpi->buf_ans, &cpi->common.error,
+                      1 << cpi->common.ans_window_size_log2);
+  }
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
 }
 
 #ifndef M_LOG2_E
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 83ed874..9416ad7 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -23,6 +23,7 @@
 #include "av1/common/onyxc_int.h"
 #include "av1/encoder/aq_cyclicrefresh.h"
 #if CONFIG_ANS
+#include "aom_dsp/ans.h"
 #include "aom_dsp/buf_ans.h"
 #endif
 #include "av1/encoder/context_tree.h"
@@ -264,6 +265,9 @@
 #if CONFIG_EXT_PARTITION
   aom_superblock_size_t superblock_size;
 #endif  // CONFIG_EXT_PARTITION
+#if CONFIG_ANS && ANS_MAX_SYMBOLS
+  int ans_window_size_log2;  // log2 of the ANS window size, in [8, 23]
+#endif  // CONFIG_ANS && ANS_MAX_SYMBOLS
 } AV1EncoderConfig;
 
 static INLINE int is_lossless_requested(const AV1EncoderConfig *cfg) {