Use weighted averaging of CDFs in enc row-mt

In case of row-based multi-threading of encoder, since we always
keep a top-right sync, we can average the top-right SB's CDFs and
the left SB's CDFs and use the same for current SB's encoding to
improve the performance.

Using weighted averaging i.e., giving more weight (75%) to left
SB’s CDF and less weight (25%) to top-right SB’s CDF, further
helps in reducing the BD rate gap between enc row-mt and tile-mt.

STATS_CHANGED for row-mt.

Change-Id: Ic66a4710b198d649f6077d7437d789c96e743d4a
diff --git a/av1/encoder/av1_multi_thread.c b/av1/encoder/av1_multi_thread.c
index a0c556e..ff4c4ac 100644
--- a/av1/encoder/av1_multi_thread.c
+++ b/av1/encoder/av1_multi_thread.c
@@ -35,6 +35,12 @@
           &cpi->tile_data[tile_row * multi_thread_ctxt->allocated_tile_cols +
                           tile_col];
       av1_row_mt_sync_mem_alloc(&this_tile->row_mt_sync, cm, max_sb_rows);
+      CHECK_MEM_ERROR(
+          cm, this_tile->row_ctx,
+          (FRAME_CONTEXT *)aom_memalign(
+              16, AOMMAX(1, (av1_get_sb_cols_in_tile(cm, this_tile->tile_info) -
+                             1)) *
+                      sizeof(*this_tile->row_ctx)));
     }
   }
 }
@@ -53,6 +59,7 @@
           &cpi->tile_data[tile_row * multi_thread_ctxt->allocated_tile_cols +
                           tile_col];
       av1_row_mt_sync_mem_dealloc(&this_tile->row_mt_sync);
+      aom_free(this_tile->row_ctx);
     }
   }
   multi_thread_ctxt->allocated_sb_rows = 0;
diff --git a/av1/encoder/block.h b/av1/encoder/block.h
index 0471de5..5df6677 100644
--- a/av1/encoder/block.h
+++ b/av1/encoder/block.h
@@ -277,7 +277,7 @@
   CONV_BUF_TYPE *tmp_conv_dst;
   uint8_t *tmp_obmc_bufs[2];
 
-  FRAME_CONTEXT *backup_tile_ctx;
+  FRAME_CONTEXT *row_ctx;
   // This context will be used to update color_map_cdf pointer which would be
   // used during pack bitstream. For single thread and tile-multithreading case
   // this ponter will be same as xd->tile_ctx, but for the case of row-mt:
diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c
index 38da2a1..85caff65 100644
--- a/av1/encoder/encodeframe.c
+++ b/av1/encoder/encodeframe.c
@@ -5317,6 +5317,182 @@
   }
 }
 
+#define AVG_CDF_WEIGHT_LEFT 3
+#define AVG_CDF_WEIGHT_TOP_RIGHT 1
+
+static void avg_cdf_symbol(aom_cdf_prob *cdf_ptr_left, aom_cdf_prob *cdf_ptr_tr,
+                           int num_cdfs, int cdf_stride, int nsymbs,
+                           int wt_left, int wt_tr) {
+  for (int i = 0; i < num_cdfs; i++) {
+    for (int j = 0; j <= nsymbs; j++) {
+      cdf_ptr_left[i * cdf_stride + j] =
+          (aom_cdf_prob)(((int)cdf_ptr_left[i * cdf_stride + j] * wt_left +
+                          (int)cdf_ptr_tr[i * cdf_stride + j] * wt_tr +
+                          ((wt_left + wt_tr) / 2)) /
+                         (wt_left + wt_tr));
+      assert(cdf_ptr_left[i * cdf_stride + j] >= 0 &&
+             cdf_ptr_left[i * cdf_stride + j] < CDF_PROB_TOP);
+    }
+  }
+}
+
+#define AVERAGE_CDF(cname_left, cname_tr, nsymbs) \
+  AVG_CDF_STRIDE(cname_left, cname_tr, nsymbs, CDF_SIZE(nsymbs))
+
+#define AVG_CDF_STRIDE(cname_left, cname_tr, nsymbs, cdf_stride)           \
+  do {                                                                     \
+    aom_cdf_prob *cdf_ptr_left = (aom_cdf_prob *)cname_left;               \
+    aom_cdf_prob *cdf_ptr_tr = (aom_cdf_prob *)cname_tr;                   \
+    int array_size = (int)sizeof(cname_left) / sizeof(aom_cdf_prob);       \
+    int num_cdfs = array_size / cdf_stride;                                \
+    avg_cdf_symbol(cdf_ptr_left, cdf_ptr_tr, num_cdfs, cdf_stride, nsymbs, \
+                   wt_left, wt_tr);                                        \
+  } while (0)
+
+static void avg_nmv(nmv_context *nmv_left, nmv_context *nmv_tr, int wt_left,
+                    int wt_tr) {
+  AVERAGE_CDF(nmv_left->joints_cdf, nmv_tr->joints_cdf, 4);
+  for (int i = 0; i < 2; i++) {
+    AVERAGE_CDF(nmv_left->comps[i].classes_cdf, nmv_tr->comps[i].classes_cdf,
+                MV_CLASSES);
+    AVERAGE_CDF(nmv_left->comps[i].class0_fp_cdf,
+                nmv_tr->comps[i].class0_fp_cdf, MV_FP_SIZE);
+    AVERAGE_CDF(nmv_left->comps[i].fp_cdf, nmv_tr->comps[i].fp_cdf, MV_FP_SIZE);
+    AVERAGE_CDF(nmv_left->comps[i].sign_cdf, nmv_tr->comps[i].sign_cdf, 2);
+    AVERAGE_CDF(nmv_left->comps[i].class0_hp_cdf,
+                nmv_tr->comps[i].class0_hp_cdf, 2);
+    AVERAGE_CDF(nmv_left->comps[i].hp_cdf, nmv_tr->comps[i].hp_cdf, 2);
+    AVERAGE_CDF(nmv_left->comps[i].class0_cdf, nmv_tr->comps[i].class0_cdf,
+                CLASS0_SIZE);
+    AVERAGE_CDF(nmv_left->comps[i].bits_cdf, nmv_tr->comps[i].bits_cdf, 2);
+  }
+}
+
+// In case of row-based multi-threading of encoder, since we always
+// keep a top - right sync, we can average the top - right SB's CDFs and
+// the left SB's CDFs and use the same for current SB's encoding to
+// improve the performance. This function facilitates the averaging
+// of CDF and used only when row-mt is enabled in encoder.
+static void avg_cdf_symbols(FRAME_CONTEXT *ctx_left, FRAME_CONTEXT *ctx_tr,
+                            int wt_left, int wt_tr) {
+  AVERAGE_CDF(ctx_left->txb_skip_cdf, ctx_tr->txb_skip_cdf, 2);
+  AVERAGE_CDF(ctx_left->eob_extra_cdf, ctx_tr->eob_extra_cdf, 2);
+  AVERAGE_CDF(ctx_left->dc_sign_cdf, ctx_tr->dc_sign_cdf, 2);
+  AVERAGE_CDF(ctx_left->eob_flag_cdf16, ctx_tr->eob_flag_cdf16, 5);
+  AVERAGE_CDF(ctx_left->eob_flag_cdf32, ctx_tr->eob_flag_cdf32, 6);
+  AVERAGE_CDF(ctx_left->eob_flag_cdf64, ctx_tr->eob_flag_cdf64, 7);
+  AVERAGE_CDF(ctx_left->eob_flag_cdf128, ctx_tr->eob_flag_cdf128, 8);
+  AVERAGE_CDF(ctx_left->eob_flag_cdf256, ctx_tr->eob_flag_cdf256, 9);
+  AVERAGE_CDF(ctx_left->eob_flag_cdf512, ctx_tr->eob_flag_cdf512, 10);
+  AVERAGE_CDF(ctx_left->eob_flag_cdf1024, ctx_tr->eob_flag_cdf1024, 11);
+  AVERAGE_CDF(ctx_left->coeff_base_eob_cdf, ctx_tr->coeff_base_eob_cdf, 3);
+  AVERAGE_CDF(ctx_left->coeff_base_cdf, ctx_tr->coeff_base_cdf, 4);
+  AVERAGE_CDF(ctx_left->coeff_br_cdf, ctx_tr->coeff_br_cdf, BR_CDF_SIZE);
+  AVERAGE_CDF(ctx_left->newmv_cdf, ctx_tr->newmv_cdf, 2);
+  AVERAGE_CDF(ctx_left->zeromv_cdf, ctx_tr->zeromv_cdf, 2);
+  AVERAGE_CDF(ctx_left->refmv_cdf, ctx_tr->refmv_cdf, 2);
+  AVERAGE_CDF(ctx_left->drl_cdf, ctx_tr->drl_cdf, 2);
+  AVERAGE_CDF(ctx_left->inter_compound_mode_cdf,
+              ctx_tr->inter_compound_mode_cdf, INTER_COMPOUND_MODES);
+  AVERAGE_CDF(ctx_left->compound_type_cdf, ctx_tr->compound_type_cdf,
+              COMPOUND_TYPES - 1);
+  AVERAGE_CDF(ctx_left->wedge_idx_cdf, ctx_tr->wedge_idx_cdf, 16);
+  AVERAGE_CDF(ctx_left->interintra_cdf, ctx_tr->interintra_cdf, 2);
+  AVERAGE_CDF(ctx_left->wedge_interintra_cdf, ctx_tr->wedge_interintra_cdf, 2);
+  AVERAGE_CDF(ctx_left->interintra_mode_cdf, ctx_tr->interintra_mode_cdf,
+              INTERINTRA_MODES);
+  AVERAGE_CDF(ctx_left->motion_mode_cdf, ctx_tr->motion_mode_cdf, MOTION_MODES);
+  AVERAGE_CDF(ctx_left->obmc_cdf, ctx_tr->obmc_cdf, 2);
+  AVERAGE_CDF(ctx_left->palette_y_size_cdf, ctx_tr->palette_y_size_cdf,
+              PALETTE_SIZES);
+  AVERAGE_CDF(ctx_left->palette_uv_size_cdf, ctx_tr->palette_uv_size_cdf,
+              PALETTE_SIZES);
+  for (int j = 0; j < PALETTE_SIZES; j++) {
+    int nsymbs = j + PALETTE_MIN_SIZE;
+    AVG_CDF_STRIDE(ctx_left->palette_y_color_index_cdf[j],
+                   ctx_tr->palette_y_color_index_cdf[j], nsymbs,
+                   CDF_SIZE(PALETTE_COLORS));
+    AVG_CDF_STRIDE(ctx_left->palette_uv_color_index_cdf[j],
+                   ctx_tr->palette_uv_color_index_cdf[j], nsymbs,
+                   CDF_SIZE(PALETTE_COLORS));
+  }
+  AVERAGE_CDF(ctx_left->palette_y_mode_cdf, ctx_tr->palette_y_mode_cdf, 2);
+  AVERAGE_CDF(ctx_left->palette_uv_mode_cdf, ctx_tr->palette_uv_mode_cdf, 2);
+  AVERAGE_CDF(ctx_left->comp_inter_cdf, ctx_tr->comp_inter_cdf, 2);
+  AVERAGE_CDF(ctx_left->single_ref_cdf, ctx_tr->single_ref_cdf, 2);
+  AVERAGE_CDF(ctx_left->comp_ref_type_cdf, ctx_tr->comp_ref_type_cdf, 2);
+  AVERAGE_CDF(ctx_left->uni_comp_ref_cdf, ctx_tr->uni_comp_ref_cdf, 2);
+  AVERAGE_CDF(ctx_left->comp_ref_cdf, ctx_tr->comp_ref_cdf, 2);
+  AVERAGE_CDF(ctx_left->comp_bwdref_cdf, ctx_tr->comp_bwdref_cdf, 2);
+  AVERAGE_CDF(ctx_left->txfm_partition_cdf, ctx_tr->txfm_partition_cdf, 2);
+  AVERAGE_CDF(ctx_left->compound_index_cdf, ctx_tr->compound_index_cdf, 2);
+  AVERAGE_CDF(ctx_left->comp_group_idx_cdf, ctx_tr->comp_group_idx_cdf, 2);
+  AVERAGE_CDF(ctx_left->skip_mode_cdfs, ctx_tr->skip_mode_cdfs, 2);
+  AVERAGE_CDF(ctx_left->skip_cdfs, ctx_tr->skip_cdfs, 2);
+  AVERAGE_CDF(ctx_left->intra_inter_cdf, ctx_tr->intra_inter_cdf, 2);
+  avg_nmv(&ctx_left->nmvc, &ctx_tr->nmvc, wt_left, wt_tr);
+  avg_nmv(&ctx_left->ndvc, &ctx_tr->ndvc, wt_left, wt_tr);
+  AVERAGE_CDF(ctx_left->intrabc_cdf, ctx_tr->intrabc_cdf, 2);
+  AVERAGE_CDF(ctx_left->seg.tree_cdf, ctx_tr->seg.tree_cdf, MAX_SEGMENTS);
+  AVERAGE_CDF(ctx_left->seg.pred_cdf, ctx_tr->seg.pred_cdf, 2);
+  AVERAGE_CDF(ctx_left->seg.spatial_pred_seg_cdf,
+              ctx_tr->seg.spatial_pred_seg_cdf, MAX_SEGMENTS);
+  AVERAGE_CDF(ctx_left->filter_intra_cdfs, ctx_tr->filter_intra_cdfs, 2);
+  AVERAGE_CDF(ctx_left->filter_intra_mode_cdf, ctx_tr->filter_intra_mode_cdf,
+              FILTER_INTRA_MODES);
+  AVERAGE_CDF(ctx_left->switchable_restore_cdf, ctx_tr->switchable_restore_cdf,
+              RESTORE_SWITCHABLE_TYPES);
+  AVERAGE_CDF(ctx_left->wiener_restore_cdf, ctx_tr->wiener_restore_cdf, 2);
+  AVERAGE_CDF(ctx_left->sgrproj_restore_cdf, ctx_tr->sgrproj_restore_cdf, 2);
+  AVERAGE_CDF(ctx_left->y_mode_cdf, ctx_tr->y_mode_cdf, INTRA_MODES);
+  AVG_CDF_STRIDE(ctx_left->uv_mode_cdf[0], ctx_tr->uv_mode_cdf[0],
+                 UV_INTRA_MODES - 1, CDF_SIZE(UV_INTRA_MODES));
+  AVERAGE_CDF(ctx_left->uv_mode_cdf[1], ctx_tr->uv_mode_cdf[1], UV_INTRA_MODES);
+  for (int i = 0; i < PARTITION_CONTEXTS; i++) {
+    if (i < 4) {
+      AVG_CDF_STRIDE(ctx_left->partition_cdf[i], ctx_tr->partition_cdf[i], 4,
+                     CDF_SIZE(10));
+    } else if (i < 16) {
+      AVERAGE_CDF(ctx_left->partition_cdf[i], ctx_tr->partition_cdf[i], 10);
+    } else {
+      AVG_CDF_STRIDE(ctx_left->partition_cdf[i], ctx_tr->partition_cdf[i], 8,
+                     CDF_SIZE(10));
+    }
+  }
+  AVERAGE_CDF(ctx_left->switchable_interp_cdf, ctx_tr->switchable_interp_cdf,
+              SWITCHABLE_FILTERS);
+  AVERAGE_CDF(ctx_left->kf_y_cdf, ctx_tr->kf_y_cdf, INTRA_MODES);
+  AVERAGE_CDF(ctx_left->angle_delta_cdf, ctx_tr->angle_delta_cdf,
+              2 * MAX_ANGLE_DELTA + 1);
+  AVG_CDF_STRIDE(ctx_left->tx_size_cdf[0], ctx_tr->tx_size_cdf[0], MAX_TX_DEPTH,
+                 CDF_SIZE(MAX_TX_DEPTH + 1));
+  AVERAGE_CDF(ctx_left->tx_size_cdf[1], ctx_tr->tx_size_cdf[1],
+              MAX_TX_DEPTH + 1);
+  AVERAGE_CDF(ctx_left->tx_size_cdf[2], ctx_tr->tx_size_cdf[2],
+              MAX_TX_DEPTH + 1);
+  AVERAGE_CDF(ctx_left->tx_size_cdf[3], ctx_tr->tx_size_cdf[3],
+              MAX_TX_DEPTH + 1);
+  AVERAGE_CDF(ctx_left->delta_q_cdf, ctx_tr->delta_q_cdf, DELTA_Q_PROBS + 1);
+  AVERAGE_CDF(ctx_left->delta_lf_cdf, ctx_tr->delta_lf_cdf, DELTA_LF_PROBS + 1);
+  for (int i = 0; i < FRAME_LF_COUNT; i++) {
+    AVERAGE_CDF(ctx_left->delta_lf_multi_cdf[i], ctx_tr->delta_lf_multi_cdf[i],
+                DELTA_LF_PROBS + 1);
+  }
+  AVG_CDF_STRIDE(ctx_left->intra_ext_tx_cdf[1], ctx_tr->intra_ext_tx_cdf[1], 7,
+                 CDF_SIZE(TX_TYPES));
+  AVG_CDF_STRIDE(ctx_left->intra_ext_tx_cdf[2], ctx_tr->intra_ext_tx_cdf[2], 5,
+                 CDF_SIZE(TX_TYPES));
+  AVG_CDF_STRIDE(ctx_left->inter_ext_tx_cdf[1], ctx_tr->inter_ext_tx_cdf[1], 16,
+                 CDF_SIZE(TX_TYPES));
+  AVG_CDF_STRIDE(ctx_left->inter_ext_tx_cdf[2], ctx_tr->inter_ext_tx_cdf[2], 12,
+                 CDF_SIZE(TX_TYPES));
+  AVG_CDF_STRIDE(ctx_left->inter_ext_tx_cdf[3], ctx_tr->inter_ext_tx_cdf[3], 2,
+                 CDF_SIZE(TX_TYPES));
+  AVERAGE_CDF(ctx_left->cfl_sign_cdf, ctx_tr->cfl_sign_cdf, CFL_JOINT_SIGNS);
+  AVERAGE_CDF(ctx_left->cfl_alpha_cdf, ctx_tr->cfl_alpha_cdf,
+              CFL_ALPHABET_SIZE);
+}
+
 static void encode_rd_sb_row(AV1_COMP *cpi, ThreadData *td,
                              TileDataEnc *tile_data, int mi_row,
                              TOKENEXTRA **tp) {
@@ -5350,10 +5526,22 @@
        mi_col < tile_info->mi_col_end; mi_col += mib_size, sb_col_in_tile++) {
     (*(cpi->row_mt_sync_read_ptr))(&tile_data->row_mt_sync, sb_row,
                                    sb_col_in_tile);
-    if ((cpi->row_mt == 1) && (tile_info->mi_col_start == mi_col) &&
-        (tile_info->mi_row_start != mi_row)) {
-      // restore frame context of 1st column sb
-      memcpy(xd->tile_ctx, x->backup_tile_ctx, sizeof(*xd->tile_ctx));
+    if ((cpi->row_mt == 1) && (tile_info->mi_row_start != mi_row)) {
+      if ((tile_info->mi_col_start == mi_col)) {
+        // restore frame context of 1st column sb
+        memcpy(xd->tile_ctx, x->row_ctx, sizeof(*xd->tile_ctx));
+      } else {
+        if (tile_data->allow_update_cdf) {
+          int wt_left = AVG_CDF_WEIGHT_LEFT;
+          int wt_tr = AVG_CDF_WEIGHT_TOP_RIGHT;
+          if (tile_info->mi_col_end > (mi_col + mib_size))
+            avg_cdf_symbols(xd->tile_ctx, x->row_ctx + sb_col_in_tile, wt_left,
+                            wt_tr);
+          else
+            avg_cdf_symbols(xd->tile_ctx, x->row_ctx + sb_col_in_tile - 1,
+                            wt_left, wt_tr);
+        }
+      }
     }
     av1_fill_coeff_costs(&td->mb, xd->tile_ctx, num_planes);
     av1_fill_mode_rates(cm, x, xd->tile_ctx);
@@ -5459,11 +5647,12 @@
       av1_inter_mode_data_fit(tile_data, x->rdmult);
     }
 #endif
-    if (cpi->row_mt == 1) {
-      int update_context = 0;
-      update_context = sb_cols_in_tile == 1 || sb_col_in_tile == 1;
-      if (update_context)
-        memcpy(x->backup_tile_ctx, xd->tile_ctx, sizeof(*xd->tile_ctx));
+    if ((cpi->row_mt == 1) && (tile_info->mi_row_end > (mi_row + mib_size))) {
+      if (sb_cols_in_tile == 1)
+        memcpy(x->row_ctx, xd->tile_ctx, sizeof(*xd->tile_ctx));
+      else if (sb_col_in_tile >= 1)
+        memcpy(x->row_ctx + sb_col_in_tile - 1, xd->tile_ctx,
+               sizeof(*xd->tile_ctx));
     }
     (*(cpi->row_mt_sync_write_ptr))(&tile_data->row_mt_sync, sb_row,
                                     sb_col_in_tile, sb_cols_in_tile);
@@ -5654,7 +5843,6 @@
       cpi->td.intrabc_used = 0;
       cpi->td.mb.e_mbd.tile_ctx = &this_tile->tctx;
       cpi->td.mb.tile_pb_ctx = &this_tile->tctx;
-      cpi->td.mb.backup_tile_ctx = &this_tile->backup_tctx;
       av1_encode_tile(cpi, &cpi->td, tile_row, tile_col);
       cpi->intrabc_used |= cpi->td.intrabc_used;
     }
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index 2a522e2..609f6a0 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -511,7 +511,7 @@
   int ex_search_count;
   CFL_CTX cfl;
   DECLARE_ALIGNED(16, FRAME_CONTEXT, tctx);
-  DECLARE_ALIGNED(16, FRAME_CONTEXT, backup_tctx);
+  FRAME_CONTEXT *row_ctx;
   uint8_t allow_update_cdf;
 #if CONFIG_COLLECT_INTER_MODE_RD_STATS
   InterModeRdModel inter_mode_rd_models[BLOCK_SIZES_ALL];
diff --git a/av1/encoder/ethread.c b/av1/encoder/ethread.c
index 3f720ee..a82f2af 100644
--- a/av1/encoder/ethread.c
+++ b/av1/encoder/ethread.c
@@ -310,7 +310,7 @@
 
     td->mb.e_mbd.tile_ctx = td->tctx;
     td->mb.tile_pb_ctx = &this_tile->tctx;
-    td->mb.backup_tile_ctx = &this_tile->backup_tctx;
+    td->mb.row_ctx = this_tile->row_ctx;
     if (current_mi_row == this_tile->tile_info.mi_row_start)
       memcpy(td->mb.e_mbd.tile_ctx, &this_tile->tctx, sizeof(FRAME_CONTEXT));
     av1_init_above_context(cm, &td->mb.e_mbd, tile_row);
@@ -355,7 +355,6 @@
         &cpi->tile_data[tile_row * cm->tile_cols + tile_col];
     thread_data->td->mb.e_mbd.tile_ctx = &this_tile->tctx;
     thread_data->td->mb.tile_pb_ctx = &this_tile->tctx;
-    thread_data->td->mb.backup_tile_ctx = &this_tile->backup_tctx;
     av1_encode_tile(cpi, thread_data->td, tile_row, tile_col);
   }