Add domain transform recursive filter

This filter is intended to eventually replace the bilateral
filter; for now it is added alongside the bilateral filter
for testing.

Change-Id: Ia529701e69833d47c11b3367d5bf804eb8498079
diff --git a/av1/common/restoration.c b/av1/common/restoration.c
index 8eef650..08ab948 100644
--- a/av1/common/restoration.c
+++ b/av1/common/restoration.c
@@ -20,6 +20,16 @@
 #include "aom_mem/aom_mem.h"
 #include "aom_ports/mem.h"
 
+static int domaintxfmrf_vtable[DOMAINTXFMRF_ITERS][DOMAINTXFMRF_PARAMS][256];
+// Candidate sigma_r values, stored multiplied by DOMAINTXFMRF_SIGMA_SCALE.
+static const int domaintxfmrf_params[DOMAINTXFMRF_PARAMS] = {
+  48,  52,  56,  60,  64,  68,  72,  76,  80,  82,  84,  86,  88,
+  90,  92,  94,  96,  97,  98,  99,  100, 101, 102, 103, 104, 105,
+  106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
+  119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 130, 132, 134,
+  136, 138, 140, 142, 146, 150, 154, 158, 162, 166, 170, 174
+};
+
 #define BILATERAL_PARAM_PRECISION 16
 #define BILATERAL_AMP_RANGE 256
 #define BILATERAL_AMP_RANGE_SYM (2 * BILATERAL_AMP_RANGE + 1)
@@ -76,7 +86,26 @@
             : bilateral_level_to_params_arr[index];
 }
 
-void av1_loop_restoration_precal() {
+// Precompute domaintxfmrf_vtable[iter][sigma_r index][|pixel difference|].
+static void GenDomainTxfmRFVtable(void) {
+  const double sigma_s = sqrt(2.0);
+  int i, j, k;
+  for (i = 0; i < DOMAINTXFMRF_ITERS; ++i) {
+    const int nm = 1 << (DOMAINTXFMRF_ITERS - i - 1);
+    const double A = exp(-DOMAINTXFMRF_MULT / (sigma_s * nm));
+    for (j = 0; j < DOMAINTXFMRF_PARAMS; ++j) {
+      const double sigma_r =
+          (double)domaintxfmrf_params[j] / DOMAINTXFMRF_SIGMA_SCALE;
+      const double scale = sigma_s / sigma_r;
+      for (k = 0; k < 256; ++k) {
+        domaintxfmrf_vtable[i][j][k] =
+            RINT(DOMAINTXFMRF_VTABLE_PREC * pow(A, 1.0 + k * scale));
+      }
+    }
+  }
+}
+
+static void GenBilateralTables() {
   int i;
   for (i = 0; i < BILATERAL_LEVELS_KF; i++) {
     const BilateralParamsType param = av1_bilateral_level_to_params(i, 1);
@@ -140,6 +169,11 @@
   }
 }
 
+void av1_loop_restoration_precal(void) {
+  GenBilateralTables();     // bilateral filter lookup tables
+  GenDomainTxfmRFVtable();  // fills domaintxfmrf_vtable
+}
+
 int av1_bilateral_level_bits(const AV1_COMMON *const cm) {
   return cm->frame_type == KEY_FRAME ? BILATERAL_LEVEL_BITS_KF
                                      : BILATERAL_LEVEL_BITS;
@@ -157,18 +191,20 @@
                           &rst->nhtiles, &rst->nvtiles);
   if (rsi->frame_restoration_type == RESTORE_WIENER) {
     for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-      rsi->wiener_info[tile_idx].vfilter[RESTORATION_HALFWIN] =
-          rsi->wiener_info[tile_idx].hfilter[RESTORATION_HALFWIN] =
-              RESTORATION_FILT_STEP;
-      for (i = 0; i < RESTORATION_HALFWIN; ++i) {
-        rsi->wiener_info[tile_idx].vfilter[RESTORATION_WIN - 1 - i] =
-            rsi->wiener_info[tile_idx].vfilter[i];
-        rsi->wiener_info[tile_idx].hfilter[RESTORATION_WIN - 1 - i] =
-            rsi->wiener_info[tile_idx].hfilter[i];
-        rsi->wiener_info[tile_idx].vfilter[RESTORATION_HALFWIN] -=
-            2 * rsi->wiener_info[tile_idx].vfilter[i];
-        rsi->wiener_info[tile_idx].hfilter[RESTORATION_HALFWIN] -=
-            2 * rsi->wiener_info[tile_idx].hfilter[i];
+      if (rsi->wiener_info[tile_idx].level) {
+        rsi->wiener_info[tile_idx].vfilter[RESTORATION_HALFWIN] =
+            rsi->wiener_info[tile_idx].hfilter[RESTORATION_HALFWIN] =
+                RESTORATION_FILT_STEP;
+        for (i = 0; i < RESTORATION_HALFWIN; ++i) {
+          rsi->wiener_info[tile_idx].vfilter[RESTORATION_WIN - 1 - i] =
+              rsi->wiener_info[tile_idx].vfilter[i];
+          rsi->wiener_info[tile_idx].hfilter[RESTORATION_WIN - 1 - i] =
+              rsi->wiener_info[tile_idx].hfilter[i];
+          rsi->wiener_info[tile_idx].vfilter[RESTORATION_HALFWIN] -=
+              2 * rsi->wiener_info[tile_idx].vfilter[i];
+          rsi->wiener_info[tile_idx].hfilter[RESTORATION_HALFWIN] -=
+              2 * rsi->wiener_info[tile_idx].hfilter[i];
+        }
       }
     }
   } else if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
@@ -678,6 +714,127 @@
   aom_free(tmpbuf);
 }
 
+// Horizontal pass of domain-transform iteration `iter` on the width x height
+// block in dat: weights come from |difference| of horizontally adjacent img
+// pixels; on return the dat samples are scaled by DOMAINTXFMRF_VTABLE_PREC.
+static void apply_domaintxfmrf_hor(int iter, int param, uint8_t *img, int width,
+                                   int height, int img_stride, int32_t *dat,
+                                   int dat_stride) {
+  int i, j;
+  for (i = 0; i < height; ++i) {
+    const uint8_t *ip = &img[i * img_stride];
+    int32_t *dp = &dat[i * dat_stride];
+    // Indexing (rather than walking pointers) keeps width == 1 in bounds.
+    dp[0] *= DOMAINTXFMRF_VTABLE_PREC;
+    // left to right: dp[j] is unscaled input, dp[j - 1] is already scaled
+    for (j = 1; j < width; ++j) {
+      const int v = domaintxfmrf_vtable[iter][param][abs(ip[j] - ip[j - 1])];
+      dp[j] = dp[j] * (DOMAINTXFMRF_VTABLE_PREC - v) +
+              ((v * dp[j - 1] + DOMAINTXFMRF_VTABLE_PREC / 2) >>
+               DOMAINTXFMRF_VTABLE_PRECBITS);
+    }
+    // right to left: both neighbours are scaled, so renormalize after mixing
+    for (j = width - 2; j >= 0; --j) {
+      const int v = domaintxfmrf_vtable[iter][param][abs(ip[j + 1] - ip[j])];
+      dp[j] = (dp[j] * (DOMAINTXFMRF_VTABLE_PREC - v) + v * dp[j + 1] +
+               DOMAINTXFMRF_VTABLE_PREC / 2) >>
+              DOMAINTXFMRF_VTABLE_PRECBITS;
+    }
+  }
+}
+
+// Vertical pass of iteration `iter` over dat (already PREC-scaled by the
+// horizontal pass); weights come from |difference| of vertically adjacent img.
+static void apply_domaintxfmrf_ver(int iter, int param, uint8_t *img, int width,
+                                   int height, int img_stride, int32_t *dat,
+                                   int dat_stride) {
+  int i, j;
+  for (j = 0; j < width; ++j) {
+    const uint8_t *ip = &img[j];
+    int32_t *dp = &dat[j];
+    // top to bottom (index-based, so height == 1 stays in bounds)
+    for (i = 1; i < height; ++i) {
+      const int v = domaintxfmrf_vtable[iter][param][abs(
+          ip[i * img_stride] - ip[(i - 1) * img_stride])];
+      dp[i * dat_stride] =
+          (dp[i * dat_stride] * (DOMAINTXFMRF_VTABLE_PREC - v) +
+           (dp[(i - 1) * dat_stride] * v + DOMAINTXFMRF_VTABLE_PREC / 2)) >>
+          DOMAINTXFMRF_VTABLE_PRECBITS;
+    }
+    // bottom to top
+    for (i = height - 2; i >= 0; --i) {
+      const int v = domaintxfmrf_vtable[iter][param][abs(
+          ip[(i + 1) * img_stride] - ip[i * img_stride])];
+      dp[i * dat_stride] =
+          (dp[i * dat_stride] * (DOMAINTXFMRF_VTABLE_PREC - v) +
+           dp[(i + 1) * dat_stride] * v + DOMAINTXFMRF_VTABLE_PREC / 2) >>
+          DOMAINTXFMRF_VTABLE_PRECBITS;
+    }
+  }
+}
+
+// Drop DOMAINTXFMRF_VTABLE_PRECBITS bits from each dat sample, with rounding.
+static void apply_domaintxfmrf_reduce_prec(int32_t *dat, int width, int height,
+                                           int dat_stride) {
+  int r, c;
+  for (r = 0; r < height; ++r) {
+    int32_t *const row = dat + r * dat_stride;
+    for (c = 0; c < width; ++c)
+      row[c] = ROUND_POWER_OF_TWO_SIGNED(row[c], DOMAINTXFMRF_VTABLE_PRECBITS);
+  }
+}
+
+// Run DOMAINTXFMRF_ITERS iterations of the domain transform recursive filter
+// in place on the width x height 8-bit block at dgd; param indexes sigma_r.
+// NOTE(review): assumes width * height <= RESTORATION_TILEPELS_MAX.
+void av1_domaintxfmrf_restoration(uint8_t *dgd, int width, int height,
+                                  int stride, int param) {
+  int32_t dat[RESTORATION_TILEPELS_MAX];  // packed, row stride == width
+  int r, c, it;
+  for (r = 0; r < height; ++r) {
+    for (c = 0; c < width; ++c) dat[r * width + c] = dgd[r * stride + c];
+  }
+  for (it = 0; it < DOMAINTXFMRF_ITERS; ++it) {
+    apply_domaintxfmrf_hor(it, param, dgd, width, height, stride, dat, width);
+    apply_domaintxfmrf_ver(it, param, dgd, width, height, stride, dat, width);
+    apply_domaintxfmrf_reduce_prec(dat, width, height, width);
+  }
+  for (r = 0; r < height; ++r) {
+    uint8_t *const out = dgd + r * stride;
+    for (c = 0; c < width; ++c) out[c] = clip_pixel(dat[r * width + c]);
+  }
+}
+
+// Filter tile tile_idx of the plane at data; tiles with level 0 are skipped.
+static void loop_domaintxfmrf_filter_tile(uint8_t *data, int tile_idx,
+                                          int width, int height, int stride,
+                                          RestorationInternal *rst,
+                                          void *tmpbuf) {
+  int h_start, h_end, v_start, v_end;
+  (void)tmpbuf;  // unused: av1_domaintxfmrf_restoration has its own buffer
+
+  if (!rst->rsi->domaintxfmrf_info[tile_idx].level) return;
+  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
+                           rst->tile_width >> rst->subsampling_x,
+                           rst->tile_height >> rst->subsampling_y, width,
+                           height, 0, 0, &h_start, &h_end, &v_start, &v_end);
+  data += v_start * stride + h_start;
+  av1_domaintxfmrf_restoration(data, h_end - h_start, v_end - v_start, stride,
+                               rst->rsi->domaintxfmrf_info[tile_idx].sigma_r);
+}
+
+// Apply the domain transform filter to every restoration tile of the plane.
+static void loop_domaintxfmrf_filter(uint8_t *data, int width, int height,
+                                     int stride, RestorationInternal *rst,
+                                     uint8_t *tmpdata, int tmpstride) {
+  int tile;
+  (void)tmpdata;
+  (void)tmpstride;
+  for (tile = 0; tile < rst->ntiles; ++tile)
+    loop_domaintxfmrf_filter_tile(data, tile, width, height, stride, rst,
+                                  NULL);
+}
+
 static void loop_switchable_filter(uint8_t *data, int width, int height,
                                    int stride, RestorationInternal *rst,
                                    uint8_t *tmpdata, int tmpstride) {
@@ -703,6 +860,9 @@
     } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_SGRPROJ) {
       loop_sgrproj_filter_tile(data, tile_idx, width, height, stride, rst,
                                tmpbuf);
+    } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_DOMAINTXFMRF) {
+      loop_domaintxfmrf_filter_tile(data, tile_idx, width, height, stride, rst,
+                                    tmpbuf);
     }
   }
   aom_free(tmpbuf);
@@ -918,6 +1078,130 @@
   aom_free(tmpbuf);
 }
 
+// High-bitdepth horizontal pass. Pixels are shifted down by (bd - 8) before
+// differencing so the result indexes the 256-entry vtable (assumes samples
+// fit in bd bits); otherwise identical to apply_domaintxfmrf_hor.
+static void apply_domaintxfmrf_hor_highbd(int iter, int param, uint16_t *img,
+                                          int width, int height, int img_stride,
+                                          int32_t *dat, int dat_stride,
+                                          int bd) {
+  const int shift = (bd - 8);
+  int i, j;
+  for (i = 0; i < height; ++i) {
+    const uint16_t *ip = &img[i * img_stride];
+    int32_t *dp = &dat[i * dat_stride];
+    // Indexing (rather than walking pointers) keeps width == 1 in bounds.
+    dp[0] *= DOMAINTXFMRF_VTABLE_PREC;
+    // left to right: dp[j] is unscaled input, dp[j - 1] is already scaled
+    for (j = 1; j < width; ++j) {
+      const int v =
+          domaintxfmrf_vtable[iter][param]
+                             [abs((ip[j] >> shift) - (ip[j - 1] >> shift))];
+      dp[j] = dp[j] * (DOMAINTXFMRF_VTABLE_PREC - v) +
+              ((v * dp[j - 1] + DOMAINTXFMRF_VTABLE_PREC / 2) >>
+               DOMAINTXFMRF_VTABLE_PRECBITS);
+    }
+    // right to left: both neighbours are scaled, so renormalize after mixing
+    for (j = width - 2; j >= 0; --j) {
+      const int v =
+          domaintxfmrf_vtable[iter][param]
+                             [abs((ip[j + 1] >> shift) - (ip[j] >> shift))];
+      dp[j] = (dp[j] * (DOMAINTXFMRF_VTABLE_PREC - v) + v * dp[j + 1] +
+               DOMAINTXFMRF_VTABLE_PREC / 2) >>
+              DOMAINTXFMRF_VTABLE_PRECBITS;
+    }
+  }
+}
+
+// High-bitdepth vertical pass; pixels are shifted down by (bd - 8) before
+// differencing, otherwise identical to apply_domaintxfmrf_ver.
+static void apply_domaintxfmrf_ver_highbd(int iter, int param, uint16_t *img,
+                                          int width, int height, int img_stride,
+                                          int32_t *dat, int dat_stride,
+                                          int bd) {
+  int i, j;
+  const int shift = (bd - 8);
+  for (j = 0; j < width; ++j) {
+    const uint16_t *ip = &img[j];
+    int32_t *dp = &dat[j];
+    // top to bottom (index-based, so height == 1 stays in bounds)
+    for (i = 1; i < height; ++i) {
+      const int v = domaintxfmrf_vtable[iter][param][abs(
+          (ip[i * img_stride] >> shift) - (ip[(i - 1) * img_stride] >> shift))];
+      dp[i * dat_stride] =
+          (dp[i * dat_stride] * (DOMAINTXFMRF_VTABLE_PREC - v) +
+           (dp[(i - 1) * dat_stride] * v + DOMAINTXFMRF_VTABLE_PREC / 2)) >>
+          DOMAINTXFMRF_VTABLE_PRECBITS;
+    }
+    // bottom to top
+    for (i = height - 2; i >= 0; --i) {
+      const int v = domaintxfmrf_vtable[iter][param][abs(
+          (ip[(i + 1) * img_stride] >> shift) - (ip[i * img_stride] >> shift))];
+      dp[i * dat_stride] =
+          (dp[i * dat_stride] * (DOMAINTXFMRF_VTABLE_PREC - v) +
+           dp[(i + 1) * dat_stride] * v + DOMAINTXFMRF_VTABLE_PREC / 2) >>
+          DOMAINTXFMRF_VTABLE_PRECBITS;
+    }
+  }
+}
+
+// High-bitdepth av1_domaintxfmrf_restoration: filters the width x height
+// block at dgd in place, clipping to bit_depth (w * h must fit in dat).
+void av1_domaintxfmrf_restoration_highbd(uint16_t *dgd, int width, int height,
+                                         int stride, int param, int bit_depth) {
+  int32_t dat[RESTORATION_TILEPELS_MAX];  // packed, row stride == width
+  int r, c, it;
+  for (r = 0; r < height; ++r) {
+    for (c = 0; c < width; ++c) dat[r * width + c] = dgd[r * stride + c];
+  }
+  for (it = 0; it < DOMAINTXFMRF_ITERS; ++it) {
+    apply_domaintxfmrf_hor_highbd(it, param, dgd, width, height, stride, dat,
+                                  width, bit_depth);
+    apply_domaintxfmrf_ver_highbd(it, param, dgd, width, height, stride, dat,
+                                  width, bit_depth);
+    apply_domaintxfmrf_reduce_prec(dat, width, height, width);
+  }
+  for (r = 0; r < height; ++r) {
+    uint16_t *const out = dgd + r * stride;
+    for (c = 0; c < width; ++c)
+      out[c] = clip_pixel_highbd(dat[r * width + c], bit_depth);
+  }
+}
+
+// High-bitdepth tile filter; tiles with level 0 are skipped.
+static void loop_domaintxfmrf_filter_tile_highbd(uint16_t *data, int tile_idx,
+                                                 int width, int height,
+                                                 int stride,
+                                                 RestorationInternal *rst,
+                                                 int bit_depth, void *tmpbuf) {
+  int h_start, h_end, v_start, v_end;
+  (void)tmpbuf;  // unused: the restoration call has its own buffer
+  if (!rst->rsi->domaintxfmrf_info[tile_idx].level) return;
+  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
+                           rst->tile_width >> rst->subsampling_x,
+                           rst->tile_height >> rst->subsampling_y, width,
+                           height, 0, 0, &h_start, &h_end, &v_start, &v_end);
+  data += v_start * stride + h_start;
+  av1_domaintxfmrf_restoration_highbd(
+      data, h_end - h_start, v_end - v_start, stride,
+      rst->rsi->domaintxfmrf_info[tile_idx].sigma_r, bit_depth);
+}
+
+// High-bitdepth loop_domaintxfmrf_filter: filters every tile of the plane.
+static void loop_domaintxfmrf_filter_highbd(uint8_t *data8, int width,
+                                            int height, int stride,
+                                            RestorationInternal *rst,
+                                            uint8_t *tmpdata, int tmpstride,
+                                            int bit_depth) {
+  uint16_t *const data = CONVERT_TO_SHORTPTR(data8);
+  int tile;
+  (void)tmpdata;
+  (void)tmpstride;
+  for (tile = 0; tile < rst->ntiles; ++tile)
+    loop_domaintxfmrf_filter_tile_highbd(data, tile, width, height, stride,
+                                         rst, bit_depth, NULL);
+}
+
 static void loop_switchable_filter_highbd(uint8_t *data8, int width, int height,
                                           int stride, RestorationInternal *rst,
                                           uint8_t *tmpdata8, int tmpstride,
@@ -946,6 +1230,9 @@
     } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_SGRPROJ) {
       loop_sgrproj_filter_tile_highbd(data, tile_idx, width, height, stride,
                                       rst, bit_depth, tmpbuf);
+    } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_DOMAINTXFMRF) {
+      loop_domaintxfmrf_filter_tile_highbd(data, tile_idx, width, height,
+                                           stride, rst, bit_depth, tmpbuf);
     }
   }
   aom_free(tmpbuf);
@@ -962,14 +1249,20 @@
   const int uvstart = ystart >> cm->subsampling_y;
   int yend = end_mi_row << MI_SIZE_LOG2;
   int uvend = yend >> cm->subsampling_y;
-  restore_func_type restore_funcs[RESTORE_TYPES] = { NULL, loop_sgrproj_filter,
+  restore_func_type restore_funcs[RESTORE_TYPES] = { NULL,
+                                                     loop_sgrproj_filter,
                                                      loop_bilateral_filter,
                                                      loop_wiener_filter,
+                                                     loop_domaintxfmrf_filter,
                                                      loop_switchable_filter };
 #if CONFIG_AOM_HIGHBITDEPTH
   restore_func_highbd_type restore_funcs_highbd[RESTORE_TYPES] = {
-    NULL, loop_sgrproj_filter_highbd, loop_bilateral_filter_highbd,
-    loop_wiener_filter_highbd, loop_switchable_filter_highbd
+    NULL,
+    loop_sgrproj_filter_highbd,
+    loop_bilateral_filter_highbd,
+    loop_wiener_filter_highbd,
+    loop_domaintxfmrf_filter_highbd,
+    loop_switchable_filter_highbd
   };
 #endif  // CONFIG_AOM_HIGHBITDEPTH
   restore_func_type restore_func =