Modify and simplify propagation mechanism

The propagation mechanism in TplFrameDepStatsPropagate()
is modified based on the paper
"A Temporal Dependency Model for Rate-Distortion Optimization in
Video Coding".

We also simplify TplFrameDepStatsPropagate() by converting
TplBlockStats to TplUnitDepStats before hand so that TplFrameStats
is not needed in TplFrameDepStatsPropagate().

We also make sure inter_cost <= intra_cost when converting
TplBlockStats to TplUnitDepStatsin by TplBlockStatsToDepStats()

Bug: b/221916304
Change-Id: Iafc2af5557fa99977ba82e1cf8b45b5863c78b6f
diff --git a/av1/ratectrl_qmode.cc b/av1/ratectrl_qmode.cc
index 44b6c15..1bb1393 100644
--- a/av1/ratectrl_qmode.cc
+++ b/av1/ratectrl_qmode.cc
@@ -24,8 +24,9 @@
 // This is used before division to ensure that the divisor isn't zero or
 // too close to zero.
 static double ModifyDivisor(double divisor) {
-  const double kEpsilon = 0.000001;
-  return (divisor < 0 ? divisor - kEpsilon : divisor + kEpsilon);
+  const double kEpsilon = 0.0000001;
+  return (divisor < 0 ? std::min(divisor, -kEpsilon)
+                      : std::max(divisor, kEpsilon));
 }
 
 GopFrame GopFrameInvalid() {
@@ -693,11 +694,25 @@
       frame_width / min_block_size + !!(frame_width % min_block_size);
   TplFrameDepStats frame_dep_stats;
   frame_dep_stats.unit_size = min_block_size;
-  frame_dep_stats.unit_stats = std::vector<std::vector<double>>(
-      unit_rows, std::vector<double>(unit_cols, 0));
+  frame_dep_stats.unit_stats = std::vector<std::vector<TplUnitDepStats>>(
+      unit_rows, std::vector<TplUnitDepStats>(unit_cols));
   return frame_dep_stats;
 }
 
+TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
+                                        int unit_count) {
+  TplUnitDepStats dep_stats = {};
+  dep_stats.intra_cost = block_stats.intra_cost * 1.0 / unit_count;
+  dep_stats.inter_cost = block_stats.inter_cost * 1.0 / unit_count;
+  // In rare case, inter_cost may be greater than intra_cost.
+  // If so, we need to modify inter_cost such that inter_cost <= intra_cost
+  // because it is required by GetPropagationFraction()
+  dep_stats.inter_cost = std::min(dep_stats.intra_cost, dep_stats.inter_cost);
+  dep_stats.mv = block_stats.mv;
+  dep_stats.ref_frame_index = block_stats.ref_frame_index;
+  return dep_stats;
+}
+
 TplFrameDepStats CreateTplFrameDepStatsWithoutPropagation(
     const TplFrameStats &frame_stats) {
   const int min_block_size = frame_stats.min_block_size;
@@ -709,25 +724,25 @@
     const int unit_count = block_unit_rows * block_unit_cols;
     const int block_unit_row = block_stats.row / min_block_size;
     const int block_unit_col = block_stats.col / min_block_size;
-    const double cost_diff =
-        (block_stats.inter_cost - block_stats.intra_cost) * 1.0 / unit_count;
+    TplUnitDepStats unit_stats =
+        TplBlockStatsToDepStats(block_stats, unit_count);
     for (int r = 0; r < block_unit_rows; r++) {
       for (int c = 0; c < block_unit_cols; c++) {
         frame_dep_stats.unit_stats[block_unit_row + r][block_unit_col + c] =
-            cost_diff;
+            unit_stats;
       }
     }
   }
   return frame_dep_stats;
 }
 
-int GetRefCodingIdxList(const TplBlockStats &block_stats,
+int GetRefCodingIdxList(const TplUnitDepStats &unit_dep_stats,
                         const RefFrameTable &ref_frame_table,
                         int *ref_coding_idx_list) {
   int ref_frame_count = 0;
   for (int i = 0; i < kBlockRefCount; ++i) {
     ref_coding_idx_list[i] = -1;
-    int ref_frame_index = block_stats.ref_frame_index[i];
+    int ref_frame_index = unit_dep_stats.ref_frame_index[i];
     if (ref_frame_index != -1) {
       ref_coding_idx_list[i] = ref_frame_table[ref_frame_index].coding_idx;
       ref_frame_count++;
@@ -747,18 +762,27 @@
   return 0;
 }
 
-double TplFrameStatsAccumulate(const TplFrameStats &frame_stats) {
-  double ref_sum_cost_diff = 0;
-  for (auto &block_stats : frame_stats.block_stats_list) {
-    ref_sum_cost_diff += block_stats.inter_cost - block_stats.intra_cost;
+// TODO(angiebird): Merge TplFrameDepStatsAccumulateIntraCost and
+// TplFrameDepStatsAccumulate.
+double TplFrameDepStatsAccumulateIntraCost(
+    const TplFrameDepStats &frame_dep_stats) {
+  auto getIntraCost = [](double sum, const TplUnitDepStats &unit) {
+    return sum + unit.intra_cost;
+  };
+  double sum = 0;
+  for (const auto &row : frame_dep_stats.unit_stats) {
+    sum = std::accumulate(row.begin(), row.end(), sum, getIntraCost);
   }
-  return ref_sum_cost_diff;
+  return sum;
 }
 
 double TplFrameDepStatsAccumulate(const TplFrameDepStats &frame_dep_stats) {
+  auto getOverallCost = [](double sum, const TplUnitDepStats &unit) {
+    return sum + unit.propagation_cost + unit.intra_cost;
+  };
   double sum = 0;
   for (const auto &row : frame_dep_stats.unit_stats) {
-    sum = std::accumulate(row.begin(), row.end(), sum);
+    sum = std::accumulate(row.begin(), row.end(), sum, getOverallCost);
   }
   return sum;
 }
@@ -774,57 +798,66 @@
   return fullpel_value;
 }
 
-void TplFrameDepStatsPropagate(const TplFrameStats &frame_stats,
+double GetPropagationFraction(const TplUnitDepStats &unit_dep_stats) {
+  assert(unit_dep_stats.intra_cost >= unit_dep_stats.inter_cost);
+  return (unit_dep_stats.intra_cost - unit_dep_stats.inter_cost) /
+         ModifyDivisor(unit_dep_stats.intra_cost);
+}
+
+void TplFrameDepStatsPropagate(int coding_idx,
                                const RefFrameTable &ref_frame_table,
                                TplGopDepStats *tpl_gop_dep_stats) {
-  const int min_block_size = frame_stats.min_block_size;
+  assert(!tpl_gop_dep_stats->frame_dep_stats_list.empty());
+  TplFrameDepStats *frame_dep_stats =
+      &tpl_gop_dep_stats->frame_dep_stats_list[coding_idx];
+
+  const int unit_size = frame_dep_stats->unit_size;
   const int frame_unit_rows =
-      frame_stats.frame_height / frame_stats.min_block_size;
+      static_cast<int>(frame_dep_stats->unit_stats.size());
   const int frame_unit_cols =
-      frame_stats.frame_width / frame_stats.min_block_size;
-  for (const TplBlockStats &block_stats : frame_stats.block_stats_list) {
-    int ref_coding_idx_list[kBlockRefCount] = { -1, -1 };
-    int ref_frame_count =
-        GetRefCodingIdxList(block_stats, ref_frame_table, ref_coding_idx_list);
-    if (ref_frame_count > 0) {
-      double propagation_ratio = 1.0 / ref_frame_count;
+      static_cast<int>(frame_dep_stats->unit_stats[0].size());
+  for (int unit_row = 0; unit_row < frame_unit_rows; ++unit_row) {
+    for (int unit_col = 0; unit_col < frame_unit_cols; ++unit_col) {
+      TplUnitDepStats &unit_dep_stats =
+          frame_dep_stats->unit_stats[unit_row][unit_col];
+      int ref_coding_idx_list[kBlockRefCount] = { -1, -1 };
+      int ref_frame_count = GetRefCodingIdxList(unit_dep_stats, ref_frame_table,
+                                                ref_coding_idx_list);
+      if (ref_frame_count == 0) continue;
       for (int i = 0; i < kBlockRefCount; ++i) {
-        if (ref_coding_idx_list[i] != -1) {
-          auto &ref_frame_dep_stats =
-              tpl_gop_dep_stats->frame_dep_stats_list[ref_coding_idx_list[i]];
-          const auto &mv = block_stats.mv[i];
-          const int mv_row = GetFullpelValue(mv.row, mv.subpel_bits);
-          const int mv_col = GetFullpelValue(mv.col, mv.subpel_bits);
-          const int block_unit_rows = block_stats.height / min_block_size;
-          const int block_unit_cols = block_stats.width / min_block_size;
-          const int unit_count = block_unit_rows * block_unit_cols;
-          const double cost_diff =
-              (block_stats.inter_cost - block_stats.intra_cost) * 1.0 /
-              unit_count;
-          for (int r = 0; r < block_unit_rows; r++) {
-            for (int c = 0; c < block_unit_cols; c++) {
-              const int ref_block_row =
-                  block_stats.row + r * min_block_size + mv_row;
-              const int ref_block_col =
-                  block_stats.col + c * min_block_size + mv_col;
-              const int ref_unit_row_low = ref_block_row / min_block_size;
-              const int ref_unit_col_low = ref_block_col / min_block_size;
-              for (int j = 0; j < 2; ++j) {
-                for (int k = 0; k < 2; ++k) {
-                  const int unit_row = ref_unit_row_low + j;
-                  const int unit_col = ref_unit_col_low + k;
-                  if (unit_row >= 0 && unit_row < frame_unit_rows &&
-                      unit_col >= 0 && unit_col < frame_unit_cols) {
-                    const int overlap_area = GetBlockOverlapArea(
-                        unit_row * min_block_size, unit_col * min_block_size,
-                        ref_block_row, ref_block_col, min_block_size);
-                    const double overlap_ratio =
-                        overlap_area * 1.0 / (min_block_size * min_block_size);
-                    ref_frame_dep_stats.unit_stats[unit_row][unit_col] +=
-                        cost_diff * overlap_ratio * propagation_ratio;
-                  }
-                }
-              }
+        if (ref_coding_idx_list[i] == -1) continue;
+        TplFrameDepStats &ref_frame_dep_stats =
+            tpl_gop_dep_stats->frame_dep_stats_list[ref_coding_idx_list[i]];
+        const auto &mv = unit_dep_stats.mv[i];
+        const int mv_row = GetFullpelValue(mv.row, mv.subpel_bits);
+        const int mv_col = GetFullpelValue(mv.col, mv.subpel_bits);
+        const int ref_pixel_r = unit_row * unit_size + mv_row;
+        const int ref_pixel_c = unit_col * unit_size + mv_col;
+        const int ref_unit_row_low =
+            (unit_row * unit_size + mv_row) / unit_size;
+        const int ref_unit_col_low =
+            (unit_col * unit_size + mv_col) / unit_size;
+        for (int j = 0; j < 2; ++j) {
+          for (int k = 0; k < 2; ++k) {
+            const int ref_unit_row = ref_unit_row_low + j;
+            const int ref_unit_col = ref_unit_col_low + k;
+            if (ref_unit_row >= 0 && ref_unit_row < frame_unit_rows &&
+                ref_unit_col >= 0 && ref_unit_col < frame_unit_cols) {
+              const int overlap_area = GetBlockOverlapArea(
+                  ref_pixel_r, ref_pixel_c, ref_unit_row * unit_size,
+                  ref_unit_col * unit_size, unit_size);
+              const double overlap_ratio =
+                  overlap_area * 1.0 / (unit_size * unit_size);
+              const double propagation_fraction =
+                  GetPropagationFraction(unit_dep_stats);
+              const double propagation_ratio =
+                  1.0 / ref_frame_count * overlap_ratio * propagation_fraction;
+              TplUnitDepStats &ref_unit_stats =
+                  ref_frame_dep_stats.unit_stats[ref_unit_row][ref_unit_col];
+              ref_unit_stats.propagation_cost +=
+                  (unit_dep_stats.intra_cost +
+                   unit_dep_stats.propagation_cost) *
+                  propagation_ratio;
             }
           }
         }
@@ -867,8 +900,7 @@
     auto &ref_frame_table = ref_frame_table_list[coding_idx];
     // TODO(angiebird): Handle/test the case where reference frame
     // is in the previous GOP
-    TplFrameDepStatsPropagate(tpl_gop_stats.frame_stats_list[coding_idx],
-                              ref_frame_table, &tpl_gop_dep_stats);
+    TplFrameDepStatsPropagate(coding_idx, ref_frame_table, &tpl_gop_dep_stats);
   }
   return tpl_gop_dep_stats;
 }
@@ -902,15 +934,12 @@
   const int frame_count =
       static_cast<int>(tpl_gop_stats.frame_stats_list.size());
   for (int i = 0; i < frame_count; i++) {
-    const TplFrameStats &frame_stats = tpl_gop_stats.frame_stats_list[i];
     const TplFrameDepStats &frame_dep_stats =
         gop_dep_stats.frame_dep_stats_list[i];
     const double cost_without_propagation =
-        TplFrameStatsAccumulate(frame_stats);
+        TplFrameDepStatsAccumulateIntraCost(frame_dep_stats);
     const double cost_with_propagation =
         TplFrameDepStatsAccumulate(frame_dep_stats);
-    // TODO(angiebird): This part is still a draft. Check whether this makes
-    // sense mathematically.
     const double frame_importance =
         cost_with_propagation / cost_without_propagation;
     // Imitate the behavior of av1_tpl_get_qstep_ratio()
diff --git a/av1/ratectrl_qmode.h b/av1/ratectrl_qmode.h
index eaab68f..7602169 100644
--- a/av1/ratectrl_qmode.h
+++ b/av1/ratectrl_qmode.h
@@ -25,9 +25,17 @@
 constexpr int kMinIntervalToAddArf = 3;
 constexpr int kMinArfInterval = (kMinIntervalToAddArf + 1) / 2;
 
+struct TplUnitDepStats {
+  double propagation_cost;
+  double intra_cost;
+  double inter_cost;
+  std::array<MotionVector, kBlockRefCount> mv;
+  std::array<int, kBlockRefCount> ref_frame_index;
+};
+
 struct TplFrameDepStats {
   int unit_size;  // equivalent to min_block_size
-  std::vector<std::vector<double>> unit_stats;
+  std::vector<std::vector<TplUnitDepStats>> unit_stats;
 };
 
 struct TplGopDepStats {
@@ -65,16 +73,20 @@
 TplFrameDepStats CreateTplFrameDepStats(int frame_height, int frame_width,
                                         int min_block_size);
 
+TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
+                                        int unit_count);
+
 TplFrameDepStats CreateTplFrameDepStatsWithoutPropagation(
     const TplFrameStats &frame_stats);
 
 std::vector<int> GetKeyFrameList(const FirstpassInfo &first_pass_info);
 
-double TplFrameStatsAccumulate(const TplFrameStats &frame_stats);
+double TplFrameDepStatsAccumulateIntraCost(
+    const TplFrameDepStats &frame_dep_stats);
 
 double TplFrameDepStatsAccumulate(const TplFrameDepStats &frame_dep_stats);
 
-void TplFrameDepStatsPropagate(const TplFrameStats &frame_stats,
+void TplFrameDepStatsPropagate(int coding_idx,
                                const RefFrameTable &ref_frame_table,
                                TplGopDepStats *tpl_gop_dep_stats);
 
diff --git a/test/ratectrl_qmode_test.cc b/test/ratectrl_qmode_test.cc
index 91cb008..62239f6 100644
--- a/test/ratectrl_qmode_test.cc
+++ b/test/ratectrl_qmode_test.cc
@@ -28,6 +28,8 @@
 using ::testing::Field;
 using ::testing::Return;
 
+constexpr double kErrorEpsilon = 0.000001;
+
 void TestGopDisplayOrder(const GopStruct &gop_struct) {
   // Test whether show frames' order indices are sequential
   int expected_order_idx = 0;
@@ -161,15 +163,14 @@
 }
 
 static TplBlockStats CreateToyTplBlockStats(int h, int w, int r, int c,
-                                            int cost_diff) {
+                                            int intra_cost, int inter_cost) {
   TplBlockStats tpl_block_stats = {};
   tpl_block_stats.height = h;
   tpl_block_stats.width = w;
   tpl_block_stats.row = r;
   tpl_block_stats.col = c;
-  // A random trick that makes inter_cost - intra_cost = cost_diff;
-  tpl_block_stats.intra_cost = cost_diff / 2;
-  tpl_block_stats.inter_cost = cost_diff + cost_diff / 2;
+  tpl_block_stats.intra_cost = intra_cost;
+  tpl_block_stats.inter_cost = inter_cost;
   tpl_block_stats.ref_frame_index = { -1, -1 };
   return tpl_block_stats;
 }
@@ -191,9 +192,9 @@
         for (int v = 0; v * w < max_w; ++v) {
           int r = max_h * i + h * u;
           int c = max_w * j + w * v;
-          int cost_diff = std::rand() % 16;
+          int intra_cost = std::rand() % 16;
           TplBlockStats block_stats =
-              CreateToyTplBlockStats(h, w, r, c, cost_diff);
+              CreateToyTplBlockStats(h, w, r, c, intra_cost, 0);
           frame_stats.block_stats_list.push_back(block_stats);
         }
       }
@@ -235,6 +236,14 @@
   return { row, col, 0 };
 }
 
+double TplFrameStatsAccumulateIntraCost(const TplFrameStats &frame_stats) {
+  double sum = 0;
+  for (auto &block_stats : frame_stats.block_stats_list) {
+    sum += block_stats.intra_cost;
+  }
+  return sum;
+}
+
 TEST(RateControlQModeTest, CreateTplFrameDepStats) {
   TplFrameStats frame_stats = CreateToyTplFrameStatsWithDiffSizes(8, 16);
   TplFrameDepStats frame_dep_stats =
@@ -244,10 +253,12 @@
   const int unit_cols = static_cast<int>(frame_dep_stats.unit_stats[0].size());
   EXPECT_EQ(frame_stats.frame_height, unit_rows * frame_dep_stats.unit_size);
   EXPECT_EQ(frame_stats.frame_width, unit_cols * frame_dep_stats.unit_size);
-  const double sum_cost_diff = TplFrameDepStatsAccumulate(frame_dep_stats);
+  const double intra_cost_sum =
+      TplFrameDepStatsAccumulateIntraCost(frame_dep_stats);
 
-  const double ref_sum_cost_diff = TplFrameStatsAccumulate(frame_stats);
-  EXPECT_NEAR(sum_cost_diff, ref_sum_cost_diff, 0.0000001);
+  const double expected_intra_cost_sum =
+      TplFrameStatsAccumulateIntraCost(frame_stats);
+  EXPECT_NEAR(intra_cost_sum, expected_intra_cost_sum, kErrorEpsilon);
 }
 
 TEST(RateControlQModeTest, GetBlockOverlapArea) {
@@ -265,6 +276,20 @@
   }
 }
 
+TEST(RateControlQModeTest, TplBlockStatsToDepStats) {
+  const int intra_cost = 100;
+  const int inter_cost = 120;
+  const int unit_count = 2;
+  TplBlockStats block_stats =
+      CreateToyTplBlockStats(8, 4, 0, 0, intra_cost, inter_cost);
+  TplUnitDepStats unit_stats = TplBlockStatsToDepStats(block_stats, unit_count);
+  double expected_intra_cost = intra_cost * 1.0 / unit_count;
+  EXPECT_NEAR(unit_stats.intra_cost, expected_intra_cost, kErrorEpsilon);
+  // When inter_cost >= intra_cost in block_stats, in unit_stats,
+  // the inter_cost will be modified so that it's upper-bounded by intra_cost.
+  EXPECT_LE(unit_stats.inter_cost, unit_stats.intra_cost);
+}
+
 TEST(RateControlQModeTest, TplFrameDepStatsPropagateSingleZeroMotion) {
   // cur frame with coding_idx 1 use ref frame with coding_idx 0
   const std::array<int, kBlockRefCount> ref_frame_index = { 0, -1 };
@@ -285,19 +310,20 @@
   gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats1);
 
   const RefFrameTable ref_frame_table = CreateToyRefFrameTable(frame_count);
-  TplFrameDepStatsPropagate(frame_stats, ref_frame_table, &gop_dep_stats);
+  TplFrameDepStatsPropagate(/*coding_idx=*/1, ref_frame_table, &gop_dep_stats);
 
   // cur frame with coding_idx 1
-  const double ref_sum_cost_diff = TplFrameStatsAccumulate(frame_stats);
+  const double expected_propagation_sum =
+      TplFrameStatsAccumulateIntraCost(frame_stats);
 
   // ref frame with coding_idx 0
-  const double sum_cost_diff =
+  const double propagation_sum =
       TplFrameDepStatsAccumulate(gop_dep_stats.frame_dep_stats_list[0]);
 
-  // The sum_cost_diff between coding_idx 0 and coding_idx 1 should be equal
+  // The propagation_sum between coding_idx 0 and coding_idx 1 should be equal
   // because every block in cur frame has zero motion, use ref frame with
   // coding_idx 0 for prediction, and ref frame itself is empty.
-  EXPECT_NEAR(sum_cost_diff, ref_sum_cost_diff, 0.0000001);
+  EXPECT_NEAR(propagation_sum, expected_propagation_sum, kErrorEpsilon);
 }
 
 TEST(RateControlQModeTest, TplFrameDepStatsPropagateCompoundZeroMotion) {
@@ -326,20 +352,20 @@
   gop_dep_stats.frame_dep_stats_list.push_back(frame_dep_stats2);
 
   const RefFrameTable ref_frame_table = CreateToyRefFrameTable(frame_count);
-  TplFrameDepStatsPropagate(frame_stats, ref_frame_table, &gop_dep_stats);
+  TplFrameDepStatsPropagate(/*coding_idx=*/2, ref_frame_table, &gop_dep_stats);
 
   // cur frame with coding_idx 1
-  const double ref_sum_cost_diff = TplFrameStatsAccumulate(frame_stats);
+  const double expected_ref_sum = TplFrameStatsAccumulateIntraCost(frame_stats);
 
   // ref frame with coding_idx 0
-  const double sum_cost_diff0 =
+  const double cost_sum0 =
       TplFrameDepStatsAccumulate(gop_dep_stats.frame_dep_stats_list[0]);
-  EXPECT_NEAR(sum_cost_diff0, ref_sum_cost_diff * 0.5, 0.0000001);
+  EXPECT_NEAR(cost_sum0, expected_ref_sum * 0.5, kErrorEpsilon);
 
   // ref frame with coding_idx 1
-  const double sum_cost_diff1 =
+  const double cost_sum1 =
       TplFrameDepStatsAccumulate(gop_dep_stats.frame_dep_stats_list[1]);
-  EXPECT_NEAR(sum_cost_diff1, ref_sum_cost_diff * 0.5, 0.0000001);
+  EXPECT_NEAR(cost_sum1, expected_ref_sum * 0.5, kErrorEpsilon);
 }
 
 TEST(RateControlQModeTest, TplFrameDepStatsPropagateSingleWithMotion) {
@@ -371,7 +397,7 @@
       CreateTplFrameDepStatsWithoutPropagation(frame_stats));
 
   const RefFrameTable ref_frame_table = CreateToyRefFrameTable(frame_count);
-  TplFrameDepStatsPropagate(frame_stats, ref_frame_table, &gop_dep_stats);
+  TplFrameDepStatsPropagate(/*coding_idx=*/1, ref_frame_table, &gop_dep_stats);
 
   const auto &dep_stats0 = gop_dep_stats.frame_dep_stats_list[0];
   const auto &dep_stats1 = gop_dep_stats.frame_dep_stats_list[1];
@@ -380,17 +406,22 @@
   for (int r = 0; r < unit_rows; ++r) {
     for (int c = 0; c < unit_cols; ++c) {
       double ref_value = 0;
-      ref_value += (1 - r_ratio) * (1 - c_ratio) * dep_stats1.unit_stats[r][c];
+      ref_value += (1 - r_ratio) * (1 - c_ratio) *
+                   dep_stats1.unit_stats[r][c].intra_cost;
       if (r - 1 >= 0) {
-        ref_value += r_ratio * (1 - c_ratio) * dep_stats1.unit_stats[r - 1][c];
+        ref_value += r_ratio * (1 - c_ratio) *
+                     dep_stats1.unit_stats[r - 1][c].intra_cost;
       }
       if (c - 1 >= 0) {
-        ref_value += (1 - r_ratio) * c_ratio * dep_stats1.unit_stats[r][c - 1];
+        ref_value += (1 - r_ratio) * c_ratio *
+                     dep_stats1.unit_stats[r][c - 1].intra_cost;
       }
       if (r - 1 >= 0 && c - 1 >= 0) {
-        ref_value += r_ratio * c_ratio * dep_stats1.unit_stats[r - 1][c - 1];
+        ref_value +=
+            r_ratio * c_ratio * dep_stats1.unit_stats[r - 1][c - 1].intra_cost;
       }
-      EXPECT_NEAR(dep_stats0.unit_stats[r][c], ref_value, 0.0000001);
+      EXPECT_NEAR(dep_stats0.unit_stats[r][c].propagation_cost, ref_value,
+                  kErrorEpsilon);
     }
   }
 }
@@ -412,14 +443,15 @@
   const TplGopDepStats &gop_dep_stats =
       ComputeTplGopDepStats(tpl_gop_stats, ref_frame_table_list);
 
-  double ref_sum = 0;
+  double expected_sum = 0;
   for (int i = 2; i >= 0; i--) {
-    // Due to the linear propagation with zero motion, we can add the
-    // frame_stats value and use it as reference sum for dependency stats
-    ref_sum += TplFrameStatsAccumulate(tpl_gop_stats.frame_stats_list[i]);
+    // Due to the linear propagation with zero motion, we can accumulate the
+    // frame_stats intra_cost and use it as expected sum for dependency stats
+    expected_sum +=
+        TplFrameStatsAccumulateIntraCost(tpl_gop_stats.frame_stats_list[i]);
     const double sum =
         TplFrameDepStatsAccumulate(gop_dep_stats.frame_dep_stats_list[i]);
-    EXPECT_NEAR(sum, ref_sum, 0.0000001);
+    EXPECT_NEAR(sum, expected_sum, kErrorEpsilon);
     break;
   }
 }