Refactor av1_get_overlap_area

Simplify the code and add unit test for it.

Change-Id: I95c71e7f371aeaf81e2f3f103b0c9f8a9a042966
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index a8e0ec1..70b9407 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -869,33 +869,16 @@
   return round;
 }
 
-static int get_overlap_area(int grid_pos_row, int grid_pos_col, int ref_pos_row,
-                            int ref_pos_col, int block, BLOCK_SIZE bsize) {
-  int width = 0, height = 0;
-  int bw = 4 << mi_size_wide_log2[bsize];
-  int bh = 4 << mi_size_high_log2[bsize];
-
-  switch (block) {
-    case 0:
-      width = grid_pos_col + bw - ref_pos_col;
-      height = grid_pos_row + bh - ref_pos_row;
-      break;
-    case 1:
-      width = ref_pos_col + bw - grid_pos_col;
-      height = grid_pos_row + bh - ref_pos_row;
-      break;
-    case 2:
-      width = grid_pos_col + bw - ref_pos_col;
-      height = ref_pos_row + bh - grid_pos_row;
-      break;
-    case 3:
-      width = ref_pos_col + bw - grid_pos_col;
-      height = ref_pos_row + bh - grid_pos_row;
-      break;
-    default: assert(0);
+int av1_get_overlap_area(int row_a, int col_a, int row_b, int col_b, int width,
+                         int height) {
+  int min_row = AOMMAX(row_a, row_b);
+  int max_row = AOMMIN(row_a + height, row_b + height);
+  int min_col = AOMMAX(col_a, col_b);
+  int max_col = AOMMIN(col_a + width, col_b + width);
+  if (min_row < max_row && min_col < max_col) {
+    return (max_row - min_row) * (max_col - min_col);
   }
-  int overlap_area = width * height;
-  return overlap_area;
+  return 0;
 }
 
 int av1_tpl_ptr_pos(int mi_row, int mi_col, int stride, uint8_t right_shift) {
@@ -988,8 +971,8 @@
 
     if (grid_pos_row >= 0 && grid_pos_row < ref_tpl_frame->mi_rows * MI_SIZE &&
         grid_pos_col >= 0 && grid_pos_col < ref_tpl_frame->mi_cols * MI_SIZE) {
-      int overlap_area = get_overlap_area(
-          grid_pos_row, grid_pos_col, ref_pos_row, ref_pos_col, block, bsize);
+      int overlap_area = av1_get_overlap_area(grid_pos_row, grid_pos_col,
+                                              ref_pos_row, ref_pos_col, bw, bh);
       int ref_mi_row = round_floor(grid_pos_row, bh) * mi_height;
       int ref_mi_col = round_floor(grid_pos_col, bw) * mi_width;
       assert((1 << block_mis_log2) == mi_height);
diff --git a/av1/encoder/tpl_model.h b/av1/encoder/tpl_model.h
index 1937f3e..75f2b99 100644
--- a/av1/encoder/tpl_model.h
+++ b/av1/encoder/tpl_model.h
@@ -352,6 +352,24 @@
 int64_t av1_delta_rate_cost(int64_t delta_rate, int64_t recrf_dist,
                             int64_t srcrf_dist, int pix_num);
 
+/*!\brief  Compute the overlap area between two blocks with the same size
+ *
+ *\ingroup tpl_modelling
+ *
+ * If there is no overlap, this function should return zero.
+ *
+ * \param[in]    row_a  row position of the first block
+ * \param[in]    col_a  column position of the first block
+ * \param[in]    row_b  row position of the second block
+ * \param[in]    col_b  column position of the second block
+ * \param[in]    width  width shared by the two blocks
+ * \param[in]    height height shared by the two blocks
+ *
+ * \return overlap area of the two blocks
+ */
+int av1_get_overlap_area(int row_a, int col_a, int row_b, int col_b, int width,
+                         int height);
+
 /*!\endcond */
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/test/tpl_model_test.cc b/test/tpl_model_test.cc
index a53ee43..8d5edcd 100644
--- a/test/tpl_model_test.cc
+++ b/test/tpl_model_test.cc
@@ -108,4 +108,36 @@
   }
 }
 
+TEST(TplModelTest, GetOverlapAreaHasOverlap) {
+  // The block a's area is [10, 17) x [18, 24).
+  // The block b's area is [8, 15) x [17, 23).
+  // The overlapping area between block a and block b is [10, 15) x [18, 23).
+  // Therefore, the size of the area is (15 - 10) * (23 - 18) = 25.
+  int row_a = 10;
+  int col_a = 18;
+  int row_b = 8;
+  int col_b = 17;
+  int height = 7;
+  int width = 6;
+  int overlap_area =
+      av1_get_overlap_area(row_a, col_a, row_b, col_b, width, height);
+  EXPECT_EQ(overlap_area, 25);
+}
+
+TEST(TplModelTest, GetOverlapAreaNoOverlap) {
+  // The block a's area is [10, 14) x [18, 22).
+  // The block b's area is [5, 9) x [5, 9).
+  // Threre is no overlapping area between block a and block b.
+  // Therefore, the return value should be zero.
+  int row_a = 10;
+  int col_a = 18;
+  int row_b = 5;
+  int col_b = 5;
+  int height = 4;
+  int width = 4;
+  int overlap_area =
+      av1_get_overlap_area(row_a, col_a, row_b, col_b, width, height);
+  EXPECT_EQ(overlap_area, 0);
+}
+
 }  // namespace