Refactor av1_get_overlap_area
Simplify the code and add unit test for it.
Change-Id: I95c71e7f371aeaf81e2f3f103b0c9f8a9a042966
diff --git a/av1/encoder/tpl_model.c b/av1/encoder/tpl_model.c
index a8e0ec1..70b9407 100644
--- a/av1/encoder/tpl_model.c
+++ b/av1/encoder/tpl_model.c
@@ -869,33 +869,16 @@
return round;
}
-static int get_overlap_area(int grid_pos_row, int grid_pos_col, int ref_pos_row,
- int ref_pos_col, int block, BLOCK_SIZE bsize) {
- int width = 0, height = 0;
- int bw = 4 << mi_size_wide_log2[bsize];
- int bh = 4 << mi_size_high_log2[bsize];
-
- switch (block) {
- case 0:
- width = grid_pos_col + bw - ref_pos_col;
- height = grid_pos_row + bh - ref_pos_row;
- break;
- case 1:
- width = ref_pos_col + bw - grid_pos_col;
- height = grid_pos_row + bh - ref_pos_row;
- break;
- case 2:
- width = grid_pos_col + bw - ref_pos_col;
- height = ref_pos_row + bh - grid_pos_row;
- break;
- case 3:
- width = ref_pos_col + bw - grid_pos_col;
- height = ref_pos_row + bh - grid_pos_row;
- break;
- default: assert(0);
+int av1_get_overlap_area(int row_a, int col_a, int row_b, int col_b, int width,
+ int height) {
+ int min_row = AOMMAX(row_a, row_b);
+ int max_row = AOMMIN(row_a + height, row_b + height);
+ int min_col = AOMMAX(col_a, col_b);
+ int max_col = AOMMIN(col_a + width, col_b + width);
+ if (min_row < max_row && min_col < max_col) {
+ return (max_row - min_row) * (max_col - min_col);
}
- int overlap_area = width * height;
- return overlap_area;
+ return 0;
}
int av1_tpl_ptr_pos(int mi_row, int mi_col, int stride, uint8_t right_shift) {
@@ -988,8 +971,8 @@
if (grid_pos_row >= 0 && grid_pos_row < ref_tpl_frame->mi_rows * MI_SIZE &&
grid_pos_col >= 0 && grid_pos_col < ref_tpl_frame->mi_cols * MI_SIZE) {
- int overlap_area = get_overlap_area(
- grid_pos_row, grid_pos_col, ref_pos_row, ref_pos_col, block, bsize);
+ int overlap_area = av1_get_overlap_area(grid_pos_row, grid_pos_col,
+ ref_pos_row, ref_pos_col, bw, bh);
int ref_mi_row = round_floor(grid_pos_row, bh) * mi_height;
int ref_mi_col = round_floor(grid_pos_col, bw) * mi_width;
assert((1 << block_mis_log2) == mi_height);
diff --git a/av1/encoder/tpl_model.h b/av1/encoder/tpl_model.h
index 1937f3e..75f2b99 100644
--- a/av1/encoder/tpl_model.h
+++ b/av1/encoder/tpl_model.h
@@ -352,6 +352,24 @@
int64_t av1_delta_rate_cost(int64_t delta_rate, int64_t recrf_dist,
int64_t srcrf_dist, int pix_num);
+/*!\brief Compute the overlap area between two blocks with the same size
+ *
+ *\ingroup tpl_modelling
+ *
+ * If there is no overlap, this function should return zero.
+ *
+ * \param[in] row_a row position of the first block
+ * \param[in] col_a column position of the first block
+ * \param[in] row_b row position of the second block
+ * \param[in] col_b column position of the second block
+ * \param[in] width width shared by the two blocks
+ * \param[in] height height shared by the two blocks
+ *
+ * \return overlap area of the two blocks
+ */
+int av1_get_overlap_area(int row_a, int col_a, int row_b, int col_b, int width,
+ int height);
+
/*!\endcond */
#ifdef __cplusplus
} // extern "C"
diff --git a/test/tpl_model_test.cc b/test/tpl_model_test.cc
index a53ee43..8d5edcd 100644
--- a/test/tpl_model_test.cc
+++ b/test/tpl_model_test.cc
@@ -108,4 +108,36 @@
}
}
+TEST(TplModelTest, GetOverlapAreaHasOverlap) {
+ // The block a's area is [10, 17) x [18, 24).
+ // The block b's area is [8, 15) x [17, 23).
+ // The overlapping area between block a and block b is [10, 15) x [18, 23).
+ // Therefore, the size of the area is (15 - 10) * (23 - 18) = 25.
+ int row_a = 10;
+ int col_a = 18;
+ int row_b = 8;
+ int col_b = 17;
+ int height = 7;
+ int width = 6;
+ int overlap_area =
+ av1_get_overlap_area(row_a, col_a, row_b, col_b, width, height);
+ EXPECT_EQ(overlap_area, 25);
+}
+
+TEST(TplModelTest, GetOverlapAreaNoOverlap) {
+ // The block a's area is [10, 14) x [18, 22).
+ // The block b's area is [5, 9) x [5, 9).
+ // Threre is no overlapping area between block a and block b.
+ // Therefore, the return value should be zero.
+ int row_a = 10;
+ int col_a = 18;
+ int row_b = 5;
+ int col_b = 5;
+ int height = 4;
+ int width = 4;
+ int overlap_area =
+ av1_get_overlap_area(row_a, col_a, row_b, col_b, width, height);
+ EXPECT_EQ(overlap_area, 0);
+}
+
} // namespace