Add support for PSNR calculation based on stream bit-depth

This patch adds support for computing PSNR at the stream bit-depth when the input bit-depth is lower than the stream bit-depth.

Based on a similar change for AV1 (libaom): https://aomedia-review.googlesource.com/c/aom/+/116721

However, unlike AV1, we change the default to use the stream bit-depth: `--psnr` given without a value now selects stream bit-depth PSNR (equivalent to `--psnr=2`).

STATS_CHANGED possibly for rare cases.

For issue #344
diff --git a/aom/aom_encoder.h b/aom/aom_encoder.h
index 87e539c..36a1087 100644
--- a/aom/aom_encoder.h
+++ b/aom/aom_encoder.h
@@ -64,6 +64,8 @@
  *  The available flags are specified by AOM_CODEC_USE_* defines.
  */
 #define AOM_CODEC_USE_PSNR 0x10000 /**< Calculate PSNR on each frame */
+/*!\brief Use stream bit-depth PSNR instead of input bit-depth PSNR. */
+#define AOM_CODEC_USE_STREAM_PSNR 0x20000
 /*!\brief Print per frame stats. */
 #define AOM_CODEC_USE_PER_FRAME_STATS 0x80000 /**< Enable printing of stats */
 
@@ -132,8 +134,17 @@
       unsigned int samples[4]; /**< Number of samples, total/y/u/v */
       uint64_t sse[4];         /**< sum squared error, total/y/u/v */
       double psnr[4];          /**< PSNR, total/y/u/v */
-    } psnr;                    /**< data for PSNR packet */
-    aom_fixed_buf_t raw;       /**< data for arbitrary packets */
+      /*!\brief Number of samples, total/y/u/v when
+       * input bit-depth < stream bit-depth.*/
+      unsigned int samples_hbd[4];
+      /*!\brief sum squared error, total/y/u/v when
+       * input bit-depth < stream bit-depth.*/
+      uint64_t sse_hbd[4];
+      /*!\brief PSNR, total/y/u/v when
+       * input bit-depth < stream bit-depth.*/
+      double psnr_hbd[4];
+    } psnr;              /**< data for PSNR packet */
+    aom_fixed_buf_t raw; /**< data for arbitrary packets */
 
     /* This packet size is fixed to allow codecs to extend this
      * interface without having to manage storage for raw packets,
diff --git a/aom_dsp/psnr.c b/aom_dsp/psnr.c
index c2d144d..d651d01 100644
--- a/aom_dsp/psnr.c
+++ b/aom_dsp/psnr.c
@@ -222,9 +222,9 @@
   uint64_t total_sse = 0;
   uint32_t total_samples = 0;
 #if CONFIG_AV2CTC_PSNR_PEAK
-  const double peak = (double)(255 << (in_bit_depth - 8));
+  double peak = (double)(255 << (in_bit_depth - 8));
 #else
-  const double peak = (double)((1 << in_bit_depth) - 1);
+  double peak = (double)((1 << in_bit_depth) - 1);
 #endif  // CONFIG_AV2CTC_PSNR_PEAK
   const unsigned int input_shift = bit_depth - in_bit_depth;
 
@@ -253,4 +253,33 @@
   psnr->samples[0] = total_samples;
   psnr->psnr[0] =
       aom_sse_to_psnr((double)total_samples, peak, (double)total_sse);
+
+  // Compute PSNR based on stream bit depth
+  if (in_bit_depth < bit_depth) {
+#if CONFIG_AV2CTC_PSNR_PEAK
+    peak = (double)(255 << (bit_depth - 8));
+#else
+    peak = (double)((1 << bit_depth) - 1);
+#endif  // CONFIG_AV2CTC_PSNR_PEAK
+    total_sse = 0;
+    total_samples = 0;
+    for (i = 0; i < 3; ++i) {
+      const int w = widths[i];
+      const int h = heights[i];
+      const uint32_t samples = w * h;
+      uint64_t sse;
+      sse = highbd_get_sse(a->buffers[i], a_strides[i], b->buffers[i],
+                           b_strides[i], w, h);
+      psnr->sse_hbd[1 + i] = sse;
+      psnr->samples_hbd[1 + i] = samples;
+      psnr->psnr_hbd[1 + i] = aom_sse_to_psnr(samples, peak, (double)sse);
+      total_sse += sse;
+      total_samples += samples;
+    }
+
+    psnr->sse_hbd[0] = total_sse;
+    psnr->samples_hbd[0] = total_samples;
+    psnr->psnr_hbd[0] =
+        aom_sse_to_psnr((double)total_samples, peak, (double)total_sse);
+  }
 }
diff --git a/aom_dsp/psnr.h b/aom_dsp/psnr.h
index 67f89ba..0b8d9d8 100644
--- a/aom_dsp/psnr.h
+++ b/aom_dsp/psnr.h
@@ -22,9 +22,12 @@
 #endif
 
 typedef struct {
-  double psnr[4];       // total/y/u/v
-  uint64_t sse[4];      // total/y/u/v
-  uint32_t samples[4];  // total/y/u/v
+  double psnr[4];           // total/y/u/v
+  uint64_t sse[4];          // total/y/u/v
+  uint32_t samples[4];      // total/y/u/v
+  double psnr_hbd[4];       // total/y/u/v when input-bit-depth < bit-depth
+  uint64_t sse_hbd[4];      // total/y/u/v when input-bit-depth < bit-depth
+  uint32_t samples_hbd[4];  // total/y/u/v when input-bit-depth < bit-depth
 } PSNR_STATS;
 
 /*!\brief Converts SSE to PSNR
diff --git a/apps/aomenc.c b/apps/aomenc.c
index a86f6c8..a9f74b8 100644
--- a/apps/aomenc.c
+++ b/apps/aomenc.c
@@ -586,10 +586,10 @@
   FILE *file;
   struct rate_hist *rate_hist;
   struct WebmOutputContext webm_ctx;
-  uint64_t psnr_sse_total;
-  uint64_t psnr_samples_total;
-  double psnr_totals[4];
-  int psnr_count;
+  uint64_t psnr_sse_total[2];
+  uint64_t psnr_samples_total[2];
+  double psnr_totals[2][4];
+  int psnr_count[2];
   int counts[QINDEX_RANGE];
   aom_codec_ctx_t encoder;
   unsigned int frames_out;
@@ -751,6 +751,7 @@
   global->passes = 1;
   global->color_type = I420;
   global->csp = AOM_CSP_UNKNOWN;
+  global->show_psnr = 0;
   global->step_frames = 1;
 
   int cfg_included = 0;
@@ -810,7 +811,10 @@
         die("--step must be positive");
       }
     } else if (arg_match(&arg, &g_av1_codec_arg_defs.psnrarg, argi)) {
-      global->show_psnr = 1;
+      if (arg.val)
+        global->show_psnr = arg_parse_int(&arg);
+      else
+        global->show_psnr = 2;
     } else if (arg_match(&arg, &g_av1_codec_arg_defs.recontest, argi)) {
       global->test_decode = arg_parse_enum_or_int(&arg);
     } else if (arg_match(&arg, &g_av1_codec_arg_defs.framerate, argi)) {
@@ -1738,7 +1742,8 @@
   int i;
   int flags = 0;
 
-  flags |= global->show_psnr ? AOM_CODEC_USE_PSNR : 0;
+  flags |= (global->show_psnr >= 1) ? AOM_CODEC_USE_PSNR : 0;
+  flags |= (global->show_psnr == 2) ? AOM_CODEC_USE_STREAM_PSNR : 0;
   flags |= global->quiet ? 0 : AOM_CODEC_USE_PER_FRAME_STATS;
 
   /* Construct Encoder Context */
@@ -1989,15 +1994,25 @@
         break;
 
       case AOM_CODEC_PSNR_PKT:
-        if (global->show_psnr) {
+        if (global->show_psnr >= 1) {
           int i;
 
-          stream->psnr_sse_total += pkt->data.psnr.sse[0];
-          stream->psnr_samples_total += pkt->data.psnr.samples[0];
+          stream->psnr_sse_total[0] += pkt->data.psnr.sse[0];
+          stream->psnr_samples_total[0] += pkt->data.psnr.samples[0];
           for (i = 0; i < 4; i++) {
-            stream->psnr_totals[i] += pkt->data.psnr.psnr[i];
+            stream->psnr_totals[0][i] += pkt->data.psnr.psnr[i];
           }
-          stream->psnr_count++;
+          stream->psnr_count[0]++;
+
+          if (stream->config.cfg.g_input_bit_depth <
+              stream->config.cfg.g_bit_depth) {
+            stream->psnr_sse_total[1] += pkt->data.psnr.sse_hbd[0];
+            stream->psnr_samples_total[1] += pkt->data.psnr.samples_hbd[0];
+            for (i = 0; i < 4; i++) {
+              stream->psnr_totals[1][i] += pkt->data.psnr.psnr_hbd[i];
+            }
+            stream->psnr_count[1]++;
+          }
         }
 
         break;
@@ -2284,6 +2299,20 @@
                 "match input format.\n",
                 stream->config.cfg.g_profile);
       }
+      if ((global.show_psnr == 2) && (stream->config.cfg.g_input_bit_depth ==
+                                      stream->config.cfg.g_bit_depth)) {
+        fprintf(stderr,
+                "Warning: --psnr=2 and --psnr=1 will provide same "
+                "results when input bit-depth == stream bit-depth, "
+                "falling back to default psnr value\n");
+        global.show_psnr = 1;
+      }
+      if (global.show_psnr < 0 || global.show_psnr > 2) {
+        fprintf(stderr,
+                "Warning: --psnr can take only 0,1,2 as values, "
+                "falling back to default psnr value\n");
+        global.show_psnr = 1;
+      }
       /* Set limit */
       stream->config.cfg.g_limit = global.limit;
     }
@@ -2457,17 +2486,23 @@
         const double kbps = (bpf * (double)global.framerate.num /
                              (double)global.framerate.den) /
                             1000.0;
+        if (global.show_psnr >= 1) {
+          const int psnr_bit_depth = (global.show_psnr == 1)
+                                         ? stream->config.cfg.g_input_bit_depth
+                                         : stream->config.cfg.g_bit_depth;
+          const int psnr_idx = (global.show_psnr == 1) ? 0 : 1;
 #if CONFIG_AV2CTC_PSNR_PEAK
-        const double peak = (255 << (stream->config.cfg.g_input_bit_depth - 8));
+          const double peak = (255 << (psnr_bit_depth - 8));
 #else
-        const double peak = (1 << stream->config.cfg.g_input_bit_depth) - 1;
+          const double peak = (1 << psnr_bit_depth) - 1;
 #endif  // CONFIG_AV2CTC_PSNR_PEAK
-        const double ovpsnr = sse_to_psnr((double)stream->psnr_samples_total,
-                                          peak, (double)stream->psnr_sse_total);
-        double psnr[4] = { 0.0 };
-        if (global.show_psnr) {
+          const double ovpsnr =
+              sse_to_psnr((double)stream->psnr_samples_total[psnr_idx], peak,
+                          (double)stream->psnr_sse_total[psnr_idx]);
+          double psnr[4] = { 0.0 };
           for (int i = 0; i < 4; i++) {
-            psnr[i] = stream->psnr_totals[i] / stream->psnr_count;
+            psnr[i] =
+                stream->psnr_totals[psnr_idx][i] / stream->psnr_count[psnr_idx];
           }
           fprintf(stdout,
                   "\n         Bitrate(kbps)  |  PSNR(Y)  |  PSNR(U)  "
diff --git a/av1/arg_defs.c b/av1/arg_defs.c
index 11dc2fd..1d69249 100644
--- a/av1/arg_defs.c
+++ b/av1/arg_defs.c
@@ -165,7 +165,12 @@
   .good_dl = ARG_DEF(NULL, "good", 0, "Use Good Quality Deadline"),
   .quietarg = ARG_DEF("q", "quiet", 0, "Do not print encode progress"),
   .verbosearg = ARG_DEF("v", "verbose", 0, "Show encoder parameters"),
-  .psnrarg = ARG_DEF(NULL, "psnr", 0, "Show PSNR in status line"),
+  .psnrarg = ARG_DEF(
+      NULL, "psnr", -1,
+      "Show PSNR in status line "
+      "(0: Disable PSNR status line display, 1: PSNR calculated using input "
+      "bit-depth, 2: PSNR calculated using stream bit-depth (default)), "
+      "takes default option when arguments are not specified"),
   .use_cfg = ARG_DEF("c", "cfg", 1, "Config file to use"),
   .recontest = ARG_DEF_ENUM(NULL, "test-decode", 1,
                             "Test encode/decode mismatch", test_decode_enum),
diff --git a/av1/av1_cx_iface.c b/av1/av1_cx_iface.c
index 58f087b..68cf339 100644
--- a/av1/av1_cx_iface.c
+++ b/av1/av1_cx_iface.c
@@ -2868,17 +2868,10 @@
 }
 
 static void calculate_psnr(AV1_COMP *cpi, PSNR_STATS *psnr) {
-  int i;
-  PSNR_STATS stats;
-
   const uint32_t in_bit_depth = cpi->oxcf.input_cfg.input_bit_depth;
   const uint32_t bit_depth = cpi->td.mb.e_mbd.bd;
   aom_calc_highbd_psnr(cpi->unfiltered_source, &cpi->common.cur_frame->buf,
-                       &stats, bit_depth, in_bit_depth);
-
-  for (i = 0; i < 4; ++i) {
-    psnr->psnr[i] = stats.psnr[i];
-  }
+                       psnr, bit_depth, in_bit_depth);
 }
 
 static void report_stats(AV1_COMP *cpi, size_t frame_size, uint64_t cx_time) {
@@ -2892,9 +2885,10 @@
 
   for (int i = 0; i < 4; ++i) {
     psnr.psnr[i] = 0;
+    psnr.psnr_hbd[i] = 0;
   }
 
-  if (cpi->b_calculate_psnr) {
+  if (cpi->b_calculate_psnr >= 1) {
     calculate_psnr(cpi, &psnr);
   }
 
@@ -2909,7 +2903,8 @@
                              ? -1
                              : ref_poc[ref_idx];
     }
-    if (cpi->b_calculate_psnr) {
+    if (cpi->b_calculate_psnr >= 1) {
+      const bool use_hbd_psnr = (cpi->b_calculate_psnr == 2);
       fprintf(stdout,
               "POC:%6d [%s][Level:%d][Q:%3d]: %10" PRIu64
               " Bytes, "
@@ -2919,8 +2914,10 @@
               cm->cur_frame->absolute_poc,
               frameType[cm->current_frame.frame_type],
               cm->cur_frame->pyramid_level, base_qindex, (uint64_t)frame_size,
-              cx_time / 1000.0, psnr.psnr[1], psnr.psnr[2], psnr.psnr[3],
-              psnr.psnr[0]);
+              cx_time / 1000.0, use_hbd_psnr ? psnr.psnr_hbd[1] : psnr.psnr[1],
+              use_hbd_psnr ? psnr.psnr_hbd[2] : psnr.psnr[2],
+              use_hbd_psnr ? psnr.psnr_hbd[3] : psnr.psnr[3],
+              use_hbd_psnr ? psnr.psnr_hbd[0] : psnr.psnr[0]);
     } else {
       fprintf(stdout,
               "POC:%6d [%s][Level:%d][Q:%3d]: %10" PRIu64
@@ -3036,6 +3033,9 @@
     if (ctx->base.init_flags & AOM_CODEC_USE_PSNR) {
       cpi->b_calculate_psnr = 1;
     }
+    if (ctx->base.init_flags & AOM_CODEC_USE_STREAM_PSNR) {
+      cpi->b_calculate_psnr = 2;
+    }
     if (ctx->base.init_flags & AOM_CODEC_USE_PER_FRAME_STATS) {
       cpi->print_per_frame_stats = 1;
     }
diff --git a/av1/encoder/encoder.c b/av1/encoder/encoder.c
index 3898040..f98c287 100644
--- a/av1/encoder/encoder.c
+++ b/av1/encoder/encoder.c
@@ -1102,18 +1102,22 @@
   cpi->b_calculate_blockiness = 1;
   cpi->b_calculate_consistency = 1;
   cpi->total_inconsistency = 0;
-  cpi->psnr.worst = 100.0;
+  cpi->psnr[0].worst = 100.0;
+  cpi->psnr[1].worst = 100.0;
   cpi->worst_ssim = 100.0;
 
-  cpi->count = 0;
+  cpi->count[0] = 0;
+  cpi->count[1] = 0;
   cpi->bytes = 0;
 #if CONFIG_SPEED_STATS
   cpi->tx_search_count = 0;
 #endif  // CONFIG_SPEED_STATS
 
   if (cpi->b_calculate_psnr) {
-    cpi->total_sq_error = 0;
-    cpi->total_samples = 0;
+    cpi->total_sq_error[0] = 0;
+    cpi->total_samples[0] = 0;
+    cpi->total_sq_error[1] = 0;
+    cpi->total_samples[1] = 0;
     cpi->tot_recode_hits = 0;
     cpi->summed_quality = 0;
     cpi->summed_weights = 0;
@@ -1353,8 +1357,9 @@
       const double rate_err = ((100.0 * (dr - target_rate)) / target_rate);
 
       if (cpi->b_calculate_psnr) {
-        const double total_psnr = aom_sse_to_psnr(
-            (double)cpi->total_samples, peak, (double)cpi->total_sq_error);
+        const double total_psnr =
+            aom_sse_to_psnr((double)cpi->total_samples[0], peak,
+                            (double)cpi->total_sq_error[0]);
         const double total_ssim =
             100 * pow(cpi->summed_quality / cpi->summed_weights, 8.0);
         snprintf(headings, sizeof(headings),
@@ -1367,24 +1372,25 @@
                  "%7.3f\t%7.3f\t%7.3f\t%7.3f\t"
                  "%7.3f\t%7.3f\t%7.3f\t%7.3f\t"
                  "%7.3f\t%7.3f\t%7.3f",
-                 dr, cpi->psnr.stat[STAT_ALL] / cpi->count, total_psnr,
-                 cpi->psnr.stat[STAT_ALL] / cpi->count, total_psnr, total_ssim,
-                 total_ssim, cpi->fastssim.stat[STAT_ALL] / cpi->count,
-                 cpi->psnrhvs.stat[STAT_ALL] / cpi->count, cpi->psnr.worst,
-                 cpi->worst_ssim, cpi->fastssim.worst, cpi->psnrhvs.worst,
-                 cpi->psnr.stat[STAT_Y] / cpi->count,
-                 cpi->psnr.stat[STAT_U] / cpi->count,
-                 cpi->psnr.stat[STAT_V] / cpi->count);
+                 dr, cpi->psnr[0].stat[STAT_ALL] / cpi->count[0], total_psnr,
+                 cpi->psnr[0].stat[STAT_ALL] / cpi->count[0], total_psnr,
+                 total_ssim, total_ssim,
+                 cpi->fastssim.stat[STAT_ALL] / cpi->count[0],
+                 cpi->psnrhvs.stat[STAT_ALL] / cpi->count[0],
+                 cpi->psnr[0].worst, cpi->worst_ssim, cpi->fastssim.worst,
+                 cpi->psnrhvs.worst, cpi->psnr[0].stat[STAT_Y] / cpi->count[0],
+                 cpi->psnr[0].stat[STAT_U] / cpi->count[0],
+                 cpi->psnr[0].stat[STAT_V] / cpi->count[0]);
 
         if (cpi->b_calculate_blockiness) {
           SNPRINT(headings, "\t  Block\tWstBlck");
-          SNPRINT2(results, "\t%7.3f", cpi->total_blockiness / cpi->count);
+          SNPRINT2(results, "\t%7.3f", cpi->total_blockiness / cpi->count[0]);
           SNPRINT2(results, "\t%7.3f", cpi->worst_blockiness);
         }
 
         if (cpi->b_calculate_consistency) {
           double consistency =
-              aom_sse_to_psnr((double)cpi->total_samples, peak,
+              aom_sse_to_psnr((double)cpi->total_samples[0], peak,
                               (double)cpi->total_inconsistency);
 
           SNPRINT(headings, "\tConsist\tWstCons");
@@ -1392,16 +1398,43 @@
           SNPRINT2(results, "\t%7.3f", cpi->worst_consistency);
         }
 
-        SNPRINT(headings, "\t    Time\tRcErr\tAbsErr");
+        SNPRINT(headings, "\t   Time\tRcErr\tAbsErr");
         SNPRINT2(results, "\t%8.0f", total_encode_time);
-        SNPRINT2(results, "\t%7.2f", rate_err);
-        SNPRINT2(results, "\t%7.2f", fabs(rate_err));
+        SNPRINT2(results, " %7.2f", rate_err);
+        SNPRINT2(results, " %7.2f", fabs(rate_err));
 
-        fprintf(f, "%s\tAPsnr611\n", headings);
-        fprintf(f, "%s\t%7.3f\n", results,
-                (6 * cpi->psnr.stat[STAT_Y] + cpi->psnr.stat[STAT_U] +
-                 cpi->psnr.stat[STAT_V]) /
-                    (cpi->count * 8));
+        SNPRINT(headings, "\tAPsnr611");
+        SNPRINT2(results, " %7.3f",
+                 (6 * cpi->psnr[0].stat[STAT_Y] + cpi->psnr[0].stat[STAT_U] +
+                  cpi->psnr[0].stat[STAT_V]) /
+                     (cpi->count[0] * 8));
+
+        const uint32_t in_bit_depth = cpi->oxcf.input_cfg.input_bit_depth;
+        const uint32_t bit_depth = cpi->td.mb.e_mbd.bd;
+        if (in_bit_depth < bit_depth) {
+          const double peak_hbd = (double)((1 << bit_depth) - 1);
+          const double total_psnr_hbd =
+              aom_sse_to_psnr((double)cpi->total_samples[1], peak_hbd,
+                              (double)cpi->total_sq_error[1]);
+          SNPRINT(headings,
+                  "\t AVGPsnrH GLBPsnrH AVPsnrPH GLPsnrPH"
+                  " AVPsnrYH APsnrCbH APsnrCrH WstPsnrH");
+          SNPRINT2(results, "\t%7.3f",
+                   cpi->psnr[1].stat[STAT_ALL] / cpi->count[1]);
+          SNPRINT2(results, "  %7.3f", total_psnr_hbd);
+          SNPRINT2(results, "  %7.3f",
+                   cpi->psnr[1].stat[STAT_ALL] / cpi->count[1]);
+          SNPRINT2(results, "  %7.3f", total_psnr_hbd);
+          SNPRINT2(results, "  %7.3f",
+                   cpi->psnr[1].stat[STAT_Y] / cpi->count[1]);
+          SNPRINT2(results, "  %7.3f",
+                   cpi->psnr[1].stat[STAT_U] / cpi->count[1]);
+          SNPRINT2(results, "  %7.3f",
+                   cpi->psnr[1].stat[STAT_V] / cpi->count[1]);
+          SNPRINT2(results, "  %7.3f", cpi->psnr[1].worst);
+        }
+        fprintf(f, "%s\n", headings);
+        fprintf(f, "%s\n", results);
       }
 
       fclose(f);
@@ -1502,6 +1535,15 @@
     pkt.data.psnr.sse[i] = psnr.sse[i];
     pkt.data.psnr.psnr[i] = psnr.psnr[i];
   }
+
+  if (in_bit_depth < bit_depth) {
+    for (i = 0; i < 4; ++i) {
+      pkt.data.psnr.samples_hbd[i] = psnr.samples_hbd[i];
+      pkt.data.psnr.sse_hbd[i] = psnr.sse_hbd[i];
+      pkt.data.psnr.psnr_hbd[i] = psnr.psnr_hbd[i];
+    }
+  }
+
   pkt.kind = AOM_CODEC_PSNR_PKT;
   aom_codec_pkt_list_add(cpi->output_pkt_list, &pkt);
 }
@@ -4245,16 +4287,17 @@
     const YV12_BUFFER_CONFIG *recon = &cpi->common.cur_frame->buf;
     double y, u, v, frame_all;
 
-    cpi->count++;
+    cpi->count[0]++;
+    cpi->count[1]++;
     if (cpi->b_calculate_psnr) {
       PSNR_STATS psnr;
       double frame_ssim2 = 0.0, weight = 0.0;
       aom_clear_system_state();
       aom_calc_highbd_psnr(orig, recon, &psnr, bit_depth, in_bit_depth);
       adjust_image_stat(psnr.psnr[1], psnr.psnr[2], psnr.psnr[3], psnr.psnr[0],
-                        &cpi->psnr);
-      cpi->total_sq_error += psnr.sse[0];
-      cpi->total_samples += psnr.samples[0];
+                        &(cpi->psnr[0]));
+      cpi->total_sq_error[0] += psnr.sse[0];
+      cpi->total_samples[0] += psnr.samples[0];
       frame_ssim2 =
           aom_highbd_calc_ssim(orig, recon, &weight, bit_depth, in_bit_depth);
 
@@ -4262,6 +4305,14 @@
       cpi->summed_quality += frame_ssim2 * weight;
       cpi->summed_weights += weight;
 
+      // Compute PSNR based on stream bit depth
+      if (in_bit_depth < bit_depth) {
+        adjust_image_stat(psnr.psnr_hbd[1], psnr.psnr_hbd[2], psnr.psnr_hbd[3],
+                          psnr.psnr_hbd[0], &cpi->psnr[1]);
+        cpi->total_sq_error[1] += psnr.sse_hbd[0];
+        cpi->total_samples[1] += psnr.samples_hbd[0];
+      }
+
 #if 0
       {
         FILE *f = fopen("q_used.stt", "a");
@@ -4343,7 +4394,7 @@
         (*size > 0 && !is_stat_generation_stage(cpi) && cm->show_frame)) {
 #else
   // Note *size = 0 indicates a dropeed frame for which psnr is not calculated
-  if (cpi->b_calculate_psnr && *size > 0) {
+  if (cpi->b_calculate_psnr >= 1 && *size > 0) {
     if (cm->show_existing_frame ||
         (!is_stat_generation_stage(cpi) && cm->show_frame)) {
 #endif  // CONFIG_OUTPUT_FRAME_BASED_ON_ORDER_HINT
diff --git a/av1/encoder/encoder.h b/av1/encoder/encoder.h
index a195c9d..0a9ac05 100644
--- a/av1/encoder/encoder.h
+++ b/av1/encoder/encoder.h
@@ -2779,10 +2779,10 @@
   uint64_t time_receive_data;
   uint64_t time_compress_data;
 
-  int count;
-  uint64_t total_sq_error;
-  uint64_t total_samples;
-  ImageStat psnr;
+  int count[2];
+  uint64_t total_sq_error[2];
+  uint64_t total_samples[2];
+  ImageStat psnr[2];
 
   double total_blockiness;
   double worst_blockiness;
@@ -2807,7 +2807,8 @@
 #endif
 
   /*!
-   * Calculates PSNR on each frame when set to 1.
+   * Calculates PSNR on each frame when set to 1 or 2.
+   * Uses stream PSNR when set to 2.
    */
   int b_calculate_psnr;
 
@@ -3626,7 +3627,7 @@
 static AOM_INLINE int is_psnr_calc_enabled(const AV1_COMP *cpi) {
   const AV1_COMMON *const cm = &cpi->common;
 
-  return cpi->b_calculate_psnr && !is_stat_generation_stage(cpi) &&
+  return cpi->b_calculate_psnr >= 1 && !is_stat_generation_stage(cpi) &&
          cm->show_frame;
 }
 
diff --git a/av1/encoder/encoder_utils.c b/av1/encoder/encoder_utils.c
index ee501a0..aef304e 100644
--- a/av1/encoder/encoder_utils.c
+++ b/av1/encoder/encoder_utils.c
@@ -897,7 +897,10 @@
 
   if (pass != 1) return;
 
-  const double psnr_diff = psnr[1].psnr[0] - psnr[0].psnr[0];
+  const bool use_hbd_psnr = (cpi->b_calculate_psnr == 2);
+  const double psnr_diff = use_hbd_psnr
+                               ? psnr[1].psnr_hbd[0] - psnr[0].psnr_hbd[0]
+                               : psnr[1].psnr[0] - psnr[0].psnr[0];
 #if CONFIG_SCC_DETERMINATION
   // Calculate % of palette mode to be chosen in a frame from mode decision.
   const double palette_ratio =