Add encoder/bitstream support for SKIP_SGR

When SKIP_SGR == 1, the encoder can now use SGR filter sets in which
one of the filters has r == 0. For a filter with r == 0, no blending
coefficient is written to or read from the bitstream.

Change-Id: I8496b87a7fa7b29f5ee9e7687bd117f93e90e649
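
Note: the matching bitstream writer change (write_sgrproj_filter() in
av1/encoder/bitstream.c) is not part of this diff. Below is a minimal
sketch of the gating it needs, mirroring the count_sgrproj_bits()
change in this patch; the function name, signature, and surrounding
plumbing are assumptions, only the r1/r2 checks are the point:

#if CONFIG_SKIP_SGR
// Hypothetical writer counterpart; the helpers come from
// aom_dsp/binary_codes_writer.h and aom_dsp/bitwriter.h.
static void write_sgrproj_filter(const SgrprojInfo *sgrproj_info,
                                 SgrprojInfo *ref_sgrproj_info,
                                 aom_writer *wb) {
  aom_write_literal(wb, sgrproj_info->ep, SGRPROJ_PARAMS_BITS);
  const sgr_params_type *params = &sgr_params[sgrproj_info->ep];
  // Only signal a blending coefficient for a filter whose radius is
  // nonzero; this must stay in sync with count_sgrproj_bits().
  if (params->r1 > 0)
    aom_write_primitive_refsubexpfin(
        wb, SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
        ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0,
        sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0);
  if (params->r2 > 0)
    aom_write_primitive_refsubexpfin(
        wb, SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
        ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1,
        sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1);
  memcpy(ref_sgrproj_info, sgrproj_info, sizeof(*sgrproj_info));
}
#endif  // CONFIG_SKIP_SGR
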
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 8491806..5d1d269 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -179,11 +179,20 @@
                                     int src_stride, const uint8_t *dat8,
                                     int dat_stride, int use_highbitdepth,
                                     int32_t *flt1, int flt1_stride,
-                                    int32_t *flt2, int flt2_stride, int *xqd) {
+                                    int32_t *flt2, int flt2_stride, int *xqd
+#if CONFIG_SKIP_SGR
+                                    ,
+                                    const sgr_params_type *params
+#endif  // CONFIG_SKIP_SGR
+) {
   int i, j;
   int64_t err = 0;
   int xq[2];
+#if CONFIG_SKIP_SGR
+  decode_xq(xqd, xq, params);
+#else   // CONFIG_SKIP_SGR
   decode_xq(xqd, xq);
+#endif  // CONFIG_SKIP_SGR
   if (!use_highbitdepth) {
     const uint8_t *src = src8;
     const uint8_t *dat = dat8;
@@ -191,9 +200,15 @@
       for (j = 0; j < width; ++j) {
         const int32_t u =
             (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
+#if CONFIG_SKIP_SGR
+        int32_t v = u << SGRPROJ_PRJ_BITS;
+        if (params->r1 > 0) v += xq[0] * (flt1[i * flt1_stride + j] - u);
+        if (params->r2 > 0) v += xq[1] * (flt2[i * flt2_stride + j] - u);
+#else   // CONFIG_SKIP_SGR
         const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
         const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
         const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
+#endif  // CONFIG_SKIP_SGR
         const int32_t e =
             ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
             src[i * src_stride + j];
@@ -207,9 +222,15 @@
       for (j = 0; j < width; ++j) {
         const int32_t u =
             (int32_t)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
+#if CONFIG_SKIP_SGR
+        int32_t v = u << SGRPROJ_PRJ_BITS;
+        if (params->r1 > 0) v += xq[0] * (flt1[i * flt1_stride + j] - u);
+        if (params->r2 > 0) v += xq[1] * (flt2[i * flt2_stride + j] - u);
+#else   // CONFIG_SKIP_SGR
         const int32_t f1 = (int32_t)flt1[i * flt1_stride + j] - u;
         const int32_t f2 = (int32_t)flt2[i * flt2_stride + j] - u;
         const int32_t v = xq[0] * f1 + xq[1] * f2 + (u << SGRPROJ_PRJ_BITS);
+#endif  // CONFIG_SKIP_SGR
         const int32_t e =
             ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) -
             src[i * src_stride + j];
@@ -224,10 +245,21 @@
 static int64_t finer_search_pixel_proj_error(
     const uint8_t *src8, int width, int height, int src_stride,
     const uint8_t *dat8, int dat_stride, int use_highbitdepth, int32_t *flt1,
-    int flt1_stride, int32_t *flt2, int flt2_stride, int start_step, int *xqd) {
+    int flt1_stride, int32_t *flt2, int flt2_stride, int start_step, int *xqd
+#if CONFIG_SKIP_SGR
+    ,
+    const sgr_params_type *params
+#endif  // CONFIG_SKIP_SGR
+) {
+#if CONFIG_SKIP_SGR
+  int64_t err = get_pixel_proj_error(
+      src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth, flt1,
+      flt1_stride, flt2, flt2_stride, xqd, params);
+#else   // CONFIG_SKIP_SGR
   int64_t err = get_pixel_proj_error(src8, width, height, src_stride, dat8,
                                      dat_stride, use_highbitdepth, flt1,
                                      flt1_stride, flt2, flt2_stride, xqd);
+#endif  // CONFIG_SKIP_SGR
   (void)start_step;
 #if USE_SGRPROJ_REFINEMENT_SEARCH
   int64_t err2;
@@ -235,13 +267,23 @@
   int tap_max[] = { SGRPROJ_PRJ_MAX0, SGRPROJ_PRJ_MAX1 };
   for (int s = start_step; s >= 1; s >>= 1) {
     for (int p = 0; p < 2; ++p) {
+#if CONFIG_SKIP_SGR
+      if ((params->r1 == 0 && p == 0) || (params->r2 == 0 && p == 1)) continue;
+#endif  // CONFIG_SKIP_SGR
       int skip = 0;
       do {
         if (xqd[p] - s >= tap_min[p]) {
           xqd[p] -= s;
+#if CONFIG_SKIP_SGR
+          err2 =
+              get_pixel_proj_error(src8, width, height, src_stride, dat8,
+                                   dat_stride, use_highbitdepth, flt1,
+                                   flt1_stride, flt2, flt2_stride, xqd, params);
+#else   // CONFIG_SKIP_SGR
           err2 = get_pixel_proj_error(src8, width, height, src_stride, dat8,
                                       dat_stride, use_highbitdepth, flt1,
                                       flt1_stride, flt2, flt2_stride, xqd);
+#endif  // CONFIG_SKIP_SGR
           if (err2 > err) {
             xqd[p] += s;
           } else {
@@ -257,9 +299,16 @@
       do {
         if (xqd[p] + s <= tap_max[p]) {
           xqd[p] += s;
+#if CONFIG_SKIP_SGR
+          err2 =
+              get_pixel_proj_error(src8, width, height, src_stride, dat8,
+                                   dat_stride, use_highbitdepth, flt1,
+                                   flt1_stride, flt2, flt2_stride, xqd, params);
+#else   // CONFIG_SKIP_SGR
           err2 = get_pixel_proj_error(src8, width, height, src_stride, dat8,
                                       dat_stride, use_highbitdepth, flt1,
                                       flt1_stride, flt2, flt2_stride, xqd);
+#endif  // CONFIG_SKIP_SGR
           if (err2 > err) {
             xqd[p] -= s;
           } else {
@@ -280,7 +329,12 @@
                               int src_stride, const uint8_t *dat8,
                               int dat_stride, int use_highbitdepth,
                               int32_t *flt1, int flt1_stride, int32_t *flt2,
-                              int flt2_stride, int *xq) {
+                              int flt2_stride, int *xq
+#if CONFIG_SKIP_SGR
+                              ,
+                              const sgr_params_type *params
+#endif  // CONFIG_SKIP_SGR
+) {
   int i, j;
   double H[2][2] = { { 0, 0 }, { 0, 0 } };
   double C[2] = { 0, 0 };
@@ -301,8 +355,15 @@
         const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
         const double s =
             (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
+#if CONFIG_SKIP_SGR
+        const double f1 =
+            (params->r1 > 0) ? (double)flt1[i * flt1_stride + j] - u : 0;
+        const double f2 =
+            (params->r2 > 0) ? (double)flt2[i * flt2_stride + j] - u : 0;
+#else   // CONFIG_SKIP_SGR
         const double f1 = (double)flt1[i * flt1_stride + j] - u;
         const double f2 = (double)flt2[i * flt2_stride + j] - u;
+#endif  // CONFIG_SKIP_SGR
         H[0][0] += f1 * f1;
         H[1][1] += f2 * f2;
         H[0][1] += f1 * f2;
@@ -318,8 +379,15 @@
         const double u = (double)(dat[i * dat_stride + j] << SGRPROJ_RST_BITS);
         const double s =
             (double)(src[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
+#if CONFIG_SKIP_SGR
+        const double f1 =
+            (params->r1 > 0) ? (double)flt1[i * flt1_stride + j] - u : 0;
+        const double f2 =
+            (params->r2 > 0) ? (double)flt2[i * flt2_stride + j] - u : 0;
+#else   // CONFIG_SKIP_SGR
         const double f1 = (double)flt1[i * flt1_stride + j] - u;
         const double f2 = (double)flt2[i * flt2_stride + j] - u;
+#endif  // CONFIG_SKIP_SGR
         H[0][0] += f1 * f1;
         H[1][1] += f2 * f2;
         H[0][1] += f1 * f2;
@@ -334,20 +402,69 @@
   H[1][0] = H[0][1];
   C[0] /= size;
   C[1] /= size;
+#if CONFIG_SKIP_SGR
+  if (params->r1 == 0) {
+    // H matrix is now only the scalar H[1][1]
+    // C vector is now only the scalar C[1]
+    Det = H[1][1];
+    if (Det < 1e-8) return;  // ill-posed, return default values
+    x[0] = 0;
+    x[1] = C[1] / Det;
+
+    xq[0] = 0;
+    xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
+  } else if (params->r2 == 0) {
+    // H matrix is now only the scalar H[0][0]
+    // C vector is now only the scalar C[0]
+    Det = H[0][0];
+    if (Det < 1e-8) return;  // ill-posed, return default values
+    x[0] = C[0] / Det;
+    x[1] = 0;
+
+    xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
+    xq[1] = 0;
+  } else {
+    Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
+    if (Det < 1e-8) return;  // ill-posed, return default values
+    x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
+    x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
+
+    xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
+    xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
+  }
+#else   // CONFIG_SKIP_SGR
   Det = (H[0][0] * H[1][1] - H[0][1] * H[1][0]);
   if (Det < 1e-8) return;  // ill-posed, return default values
   x[0] = (H[1][1] * C[0] - H[0][1] * C[1]) / Det;
   x[1] = (H[0][0] * C[1] - H[1][0] * C[0]) / Det;
   xq[0] = (int)rint(x[0] * (1 << SGRPROJ_PRJ_BITS));
   xq[1] = (int)rint(x[1] * (1 << SGRPROJ_PRJ_BITS));
+#endif  // CONFIG_SKIP_SGR
 }
 
+#if CONFIG_SKIP_SGR
+void encode_xq(int *xq, int *xqd, const sgr_params_type *params) {
+  if (params->r1 == 0) {
+    xqd[0] = 0;
+    xqd[1] = clamp((1 << SGRPROJ_PRJ_BITS) - xq[1], SGRPROJ_PRJ_MIN1,
+                   SGRPROJ_PRJ_MAX1);
+  } else if (params->r2 == 0) {
+    xqd[0] = clamp(xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
+    xqd[1] = 0;
+  } else {
+    xqd[0] = clamp(xq[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
+    xqd[1] = clamp((1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1], SGRPROJ_PRJ_MIN1,
+                   SGRPROJ_PRJ_MAX1);
+  }
+}
+#else   // CONFIG_SKIP_SGR
 void encode_xq(int *xq, int *xqd) {
   xqd[0] = xq[0];
   xqd[0] = clamp(xqd[0], SGRPROJ_PRJ_MIN0, SGRPROJ_PRJ_MAX0);
   xqd[1] = (1 << SGRPROJ_PRJ_BITS) - xqd[0] - xq[1];
   xqd[1] = clamp(xqd[1], SGRPROJ_PRJ_MIN1, SGRPROJ_PRJ_MAX1);
 }
+#endif  // CONFIG_SKIP_SGR
 
 // Apply the self-guided filter across an entire restoration unit.
 static void apply_sgr(const sgr_params_type *params, const uint8_t *dat8,
@@ -386,19 +503,33 @@
          pu_height == RESTORATION_PROC_UNIT_SIZE);
 
   for (ep = 0; ep < SGRPROJ_PARAMS; ep++) {
+    const sgr_params_type *params = &sgr_params[ep];
     int exq[2];
-    apply_sgr(&sgr_params[ep], dat8, width, height, dat_stride,
-              use_highbitdepth, bit_depth, pu_width, pu_height, flt1, flt2,
-              flt_stride);
+
+    apply_sgr(params, dat8, width, height, dat_stride, use_highbitdepth,
+              bit_depth, pu_width, pu_height, flt1, flt2, flt_stride);
     aom_clear_system_state();
+#if CONFIG_SKIP_SGR
+    get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
+                      use_highbitdepth, flt1, flt_stride, flt2, flt_stride, exq,
+                      params);
+#else   // CONFIG_SKIP_SGR
     get_proj_subspace(src8, width, height, src_stride, dat8, dat_stride,
                       use_highbitdepth, flt1, flt_stride, flt2, flt_stride,
                       exq);
+#endif  // CONFIG_SKIP_SGR
     aom_clear_system_state();
+#if CONFIG_SKIP_SGR
+    encode_xq(exq, exqd, params);
+    int64_t err = finer_search_pixel_proj_error(
+        src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth,
+        flt1, flt_stride, flt2, flt_stride, 2, exqd, params);
+#else   // CONFIG_SKIP_SGR
     encode_xq(exq, exqd);
     int64_t err = finer_search_pixel_proj_error(
         src8, width, height, src_stride, dat8, dat_stride, use_highbitdepth,
         flt1, flt_stride, flt2, flt_stride, 2, exqd);
+#endif  // CONFIG_SKIP_SGR
     if (besterr == -1 || err < besterr) {
       bestep = ep;
       besterr = err;
@@ -417,6 +548,19 @@
 static int count_sgrproj_bits(SgrprojInfo *sgrproj_info,
                               SgrprojInfo *ref_sgrproj_info) {
   int bits = SGRPROJ_PARAMS_BITS;
+#if CONFIG_SKIP_SGR
+  const sgr_params_type *params = &sgr_params[sgrproj_info->ep];
+  if (params->r1 > 0)
+    bits += aom_count_primitive_refsubexpfin(
+        SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
+        ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0,
+        sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0);
+  if (params->r2 > 0)
+    bits += aom_count_primitive_refsubexpfin(
+        SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
+        ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1,
+        sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1);
+#else   // CONFIG_SKIP_SGR
   bits += aom_count_primitive_refsubexpfin(
       SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
       ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0,
@@ -425,6 +569,7 @@
       SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
       ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1,
       sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1);
+#endif  // CONFIG_SKIP_SGR
   return bits;
 }
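
For reference, the decoder-side read mirrors the counting above: a
coefficient is only read for a filter whose radius is nonzero, and a
skipped coefficient stays at 0, matching encode_xq() in this patch.
A rough sketch of the read_sgrproj_filter() gating follows; the reader
helpers are from aom_dsp/binary_codes_reader.h, and the exact signature
and the ACCT_STR accounting argument are assumptions:

#if CONFIG_SKIP_SGR
// Hypothetical decoder counterpart (av1/decoder/decodeframe.c).
static void read_sgrproj_filter(SgrprojInfo *sgrproj_info,
                                SgrprojInfo *ref_sgrproj_info,
                                aom_reader *rb) {
  sgrproj_info->ep = aom_read_literal(rb, SGRPROJ_PARAMS_BITS, ACCT_STR);
  const sgr_params_type *params = &sgr_params[sgrproj_info->ep];
  if (params->r1 > 0)
    sgrproj_info->xqd[0] =
        aom_read_primitive_refsubexpfin(
            rb, SGRPROJ_PRJ_MAX0 - SGRPROJ_PRJ_MIN0 + 1, SGRPROJ_PRJ_SUBEXP_K,
            ref_sgrproj_info->xqd[0] - SGRPROJ_PRJ_MIN0, ACCT_STR) +
        SGRPROJ_PRJ_MIN0;
  else
    sgrproj_info->xqd[0] = 0;
  if (params->r2 > 0)
    sgrproj_info->xqd[1] =
        aom_read_primitive_refsubexpfin(
            rb, SGRPROJ_PRJ_MAX1 - SGRPROJ_PRJ_MIN1 + 1, SGRPROJ_PRJ_SUBEXP_K,
            ref_sgrproj_info->xqd[1] - SGRPROJ_PRJ_MIN1, ACCT_STR) +
        SGRPROJ_PRJ_MIN1;
  else
    sgrproj_info->xqd[1] = 0;
  memcpy(ref_sgrproj_info, sgrproj_info, sizeof(*sgrproj_info));
}
#endif  // CONFIG_SKIP_SGR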