Refactor/simplify Loop-Restoration frame filters

Removes some unnecessary flags and streamlines the implementation.
diff --git a/av1/common/blockd.c b/av1/common/blockd.c
index ce11661..09d4abb 100644
--- a/av1/common/blockd.c
+++ b/av1/common/blockd.c
@@ -324,9 +324,6 @@
     bank->bank_size_for_class[c_id] = 0;
     bank->bank_ptr_for_class[c_id] = 0;
   }
-#if CONFIG_COMBINE_PC_NS_WIENER
-  bank->frame_filter_predictors_are_set = 0;
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
 }
 
 // Add a new filter to bank
diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index e4c5d2a..34b10a7 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -2001,10 +2001,6 @@
   /*!
    * Whether frame-level filters are on or off.
    */
-  int frame_filters_on;
-  /*!
-   * Filter data - taps
-   */
   DECLARE_ALIGNED(16, int16_t,
                   allfiltertaps[WIENERNS_MAX_CLASSES * WIENERNS_YUV_MAX]);
 #if CONFIG_LR_MERGE_COEFFS
@@ -2020,10 +2016,6 @@
    * the first bank slot and in turn used as frame filter predictors.
    */
   int match_indices[WIENERNS_MAX_CLASSES];
-#if CONFIG_TEMP_LR
-  // whether frame filter is predicted from a reference picture
-  uint8_t temporal_pred_flag;
-#endif  // CONFIG_TEMP_LR
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
 } WienerNonsepInfo;
 
@@ -2041,13 +2033,6 @@
    * Pointer to the most current filter for each class.
    */
   int bank_ptr_for_class[WIENERNS_MAX_CLASSES];
-#if CONFIG_COMBINE_PC_NS_WIENER
-  /*!
-   * Whether the bank has been initialized with predictions used to better
-   * code the frame-level filters.
-   */
-  int frame_filter_predictors_are_set;
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
 } WienerNonsepInfoBank;
 
 int16_t *nsfilter_taps(WienerNonsepInfo *nsinfo, int wiener_class_id);
diff --git a/av1/common/restoration.c b/av1/common/restoration.c
index 49d5180..886d5ed 100644
--- a/av1/common/restoration.c
+++ b/av1/common/restoration.c
@@ -2957,8 +2957,6 @@
     WienerNonsepInfoBank *bank, const WienerNonsepInfo *reference,
     const int *match_indices, int base_qindex, int class_id,
     int16_t *frame_filter_dictionary, int dict_stride) {
-  assert(!bank->frame_filter_predictors_are_set);
-
   const int is_uv = 0;
   const WienernsFilterParameters *nsfilter_params =
       get_wienerns_parameters(base_qindex, is_uv);
@@ -2992,7 +2990,6 @@
 #if CONFIG_TEMP_LR
 void av1_copy_rst_frame_filters(RestorationInfo *to,
                                 const RestorationInfo *from) {
-  assert(from->frame_filters_on);
   to->frame_filters_on = from->frame_filters_on;
   to->num_filter_classes = from->num_filter_classes;
   to->frame_filters = from->frame_filters;
diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c
index f88521d..2dff540 100644
--- a/av1/decoder/decodeframe.c
+++ b/av1/decoder/decodeframe.c
@@ -127,7 +127,7 @@
 }
 
 static AOM_INLINE void loop_restoration_read_sb_coeffs(
-    const AV1_COMMON *const cm, MACROBLOCKD *xd, aom_reader *const r, int plane,
+    AV1_COMMON *cm, MACROBLOCKD *xd, aom_reader *const r, int plane,
     int runit_idx
 #if CONFIG_COMBINE_PC_NS_WIENER
     ,
@@ -2785,34 +2785,6 @@
                                             frame_filter_dictionary, dict_stride
 #endif
             );
-#if CONFIG_COMBINE_PC_NS_WIENER
-            if (plane == AOM_PLANE_Y) {
-              // TODO: Needs to be fixed.
-              RestorationInfo *rsi = (RestorationInfo *)cm->rst_info + plane;
-              if (rsi->frame_filters_on && !rsi->frame_filters_initialized &&
-                  rsi->unit_info[runit_idx].restoration_type ==
-                      RESTORE_WIENER_NONSEP
-#if CONFIG_TEMP_LR
-                  && !rsi->temporal_pred_flag
-#endif  // CONFIG_TEMP_LR
-              ) {
-                rsi->frame_filters_initialized = 1;
-                const WienerNonsepInfoBank *bank = &xd->wienerns_info[plane];
-                assert(bank->frame_filter_predictors_are_set);
-                rsi->frame_filters.num_classes = bank->filter[0].num_classes;
-                for (int c_id = 0; c_id < rsi->frame_filters.num_classes;
-                     ++c_id) {
-                  copy_nsfilter_taps_for_class(
-                      &rsi->frame_filters,
-                      av1_constref_from_wienerns_bank(bank, 0, c_id), c_id);
-                }
-#if CONFIG_TEMP_LR
-                av1_copy_rst_frame_filters(&cm->cur_frame->rst_info[plane],
-                                           rsi);
-#endif  // CONFIG_TEMP_LR
-              }
-            }
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
           }
         }
       }
@@ -3723,45 +3695,130 @@
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
 
 #if CONFIG_LR_IMPROVEMENTS
-static void read_wienerns_filter(MACROBLOCKD *xd, int is_uv,
-                                 WienerNonsepInfo *wienerns_info,
-                                 WienerNonsepInfoBank *bank, aom_reader *rb
 #if CONFIG_COMBINE_PC_NS_WIENER
-                                 ,
-                                 int base_qindex,
-                                 int16_t *frame_filter_dictionary,
-                                 int dict_stride
+static void read_wienerns_framefilters(AV1_COMMON *cm, MACROBLOCKD *xd,
+                                       int plane, aom_reader *rb,
+                                       int16_t *frame_filter_dictionary,
+                                       int dict_stride) {
+  const int base_qindex = cm->quant_params.base_qindex;
+  const int is_uv = plane != AOM_PLANE_Y;
+  RestorationInfo *rsi = &cm->rst_info[plane];
+  assert(!is_uv);
+  assert(rsi->frame_filters_on && !rsi->frame_filters_initialized);
+  int skip_filter_read_for_class[WIENERNS_MAX_CLASSES] = { 0 };
+  const int num_classes = rsi->num_filter_classes;
+  rsi->frame_filters.num_classes = num_classes;
+  assert(num_classes <= WIENERNS_MAX_CLASSES);
+#if CONFIG_LR_MERGE_COEFFS
+#if CONFIG_TEMP_LR
+  assert(!rsi->temporal_pred_flag);
+#endif  // CONFIG_TEMP_LR
+  read_match_indices(&rsi->frame_filters, rb);
+  for (int c_id = 0; c_id < num_classes; ++c_id) {
+    const int exact_match = aom_read_symbol(rb, xd->tile_ctx->merged_param_cdf,
+                                            2, ACCT_INFO("exact_match"));
+    skip_filter_read_for_class[c_id] = exact_match;
+  }
+#else
+  (void)xd;
+#endif  // CONFIG_LR_MERGE_COEFFS
+  const WienernsFilterParameters *nsfilter_params =
+      get_wienerns_parameters(base_qindex, is_uv);
+  const int(*wienerns_coeffs)[WIENERNS_COEFCFG_LEN] = nsfilter_params->coeffs;
+  WienerNonsepInfoBank bank = { 0 };
+  bank.filter[0].num_classes = num_classes;
+  for (int c_id = 0; c_id < num_classes; ++c_id) {
+    fill_first_slot_of_bank_with_filter_match(
+        &bank, &rsi->frame_filters, rsi->frame_filters.match_indices,
+        base_qindex, c_id, frame_filter_dictionary, dict_stride);
+    if (skip_filter_read_for_class[c_id]) {
+      copy_nsfilter_taps_for_class(
+          &rsi->frame_filters, av1_constref_from_wienerns_bank(&bank, 0, c_id),
+          c_id);
+      continue;
+    }
+    const WienerNonsepInfo *ref_wienerns_info =
+        av1_constref_from_wienerns_bank(&bank, 0, c_id);
+    assert(ref_wienerns_info->num_classes == num_classes);
+    int16_t *wienerns_info_nsfilter = nsfilter_taps(&rsi->frame_filters, c_id);
+    const int16_t *ref_wienerns_info_nsfilter =
+        const_nsfilter_taps(ref_wienerns_info, c_id);
+
+    memset(wienerns_info_nsfilter, 0,
+           nsfilter_params->ncoeffs * sizeof(wienerns_info_nsfilter[0]));
+
+    const int beg_feat = 0;
+    int end_feat = nsfilter_params->ncoeffs;
+    if (end_feat > 6) {
+      const int filter_length_bit =
+          aom_read_symbol(rb, xd->tile_ctx->wienerns_length_cdf[is_uv], 2,
+                          ACCT_INFO("wienerns_length"));
+      end_feat = filter_length_bit ? nsfilter_params->ncoeffs : 6;
+    }
+    assert((end_feat & 1) == 0);
+
+    int uv_sym = 0;
+    if (is_uv && end_feat > 6) {
+      uv_sym = aom_read_symbol(rb, xd->tile_ctx->wienerns_uv_sym_cdf, 2,
+                               ACCT_INFO("wienerns_uv_sym"));
+    }
+
+    for (int i = beg_feat; i < end_feat; ++i) {
+#if ENABLE_LR_4PART_CODE
+      wienerns_info_nsfilter[i] =
+          aom_read_4part_wref(
+              rb,
+              ref_wienerns_info_nsfilter[i] -
+                  wienerns_coeffs[i - beg_feat][WIENERNS_MIN_ID],
+              xd->tile_ctx->wienerns_4part_cdf
+                  [wienerns_coeffs[i - beg_feat][WIENERNS_PAR_ID]],
+              wienerns_coeffs[i - beg_feat][WIENERNS_BIT_ID],
+              ACCT_INFO("wienerns_info_nsfilter")) +
+          wienerns_coeffs[i - beg_feat][WIENERNS_MIN_ID];
+#else
+      wienerns_info_nsfilter[i] =
+          aom_read_primitive_refsubexpfin(
+              rb, (1 << wienerns_coeffs[i - beg_feat][WIENERNS_BIT_ID]),
+              wienerns_coeffs[i - beg_feat][WIENERNS_PAR_ID],
+              ref_wienerns_info_nsfilter[i] -
+                  wienerns_coeffs[i - beg_feat][WIENERNS_MIN_ID],
+              ACCT_INFO("wienerns_info_nsfilter")) +
+          wienerns_coeffs[i - beg_feat][WIENERNS_MIN_ID];
+#endif  // ENABLE_LR_4PART_CODE
+      if (uv_sym && i >= 6) {
+        // Fill in symmetrical tap without reading it
+        wienerns_info_nsfilter[i + 1] = wienerns_info_nsfilter[i];
+        i++;
+      }
+    }
+  }
+  rsi->frame_filters_initialized = 1;
+#if CONFIG_TEMP_LR
+  av1_copy_rst_frame_filters(&cm->cur_frame->rst_info[plane], rsi);
+#endif  // CONFIG_TEMP_LR
+}
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
-) {
+
+static void read_wienerns_filter(MACROBLOCKD *xd, int is_uv,
+                                 const RestorationInfo *rsi,
+                                 WienerNonsepInfo *wienerns_info,
+                                 WienerNonsepInfoBank *bank, aom_reader *rb) {
   int skip_filter_read_for_class[WIENERNS_MAX_CLASSES] = { 0 };
   int ref_for_class[WIENERNS_MAX_CLASSES] = { 0 };
   const int num_classes = wienerns_info->num_classes;
   assert(num_classes <= WIENERNS_MAX_CLASSES);
 #if CONFIG_LR_MERGE_COEFFS
 #if CONFIG_COMBINE_PC_NS_WIENER
-  const int skip_filter_read_all_classes =
-#if CONFIG_TEMP_LR
-      wienerns_info->temporal_pred_flag ||
-#endif  // CONFIG_TEMP_LR
-      (wienerns_info->frame_filters_on &&
-       bank->frame_filter_predictors_are_set);
-  if (
-#if CONFIG_TEMP_LR
-      !wienerns_info->temporal_pred_flag &&
-#endif  // CONFIG_TEMP_LR
-      wienerns_info->frame_filters_on &&
-      !bank->frame_filter_predictors_are_set) {
-    read_match_indices(wienerns_info, rb);
+  assert(IMPLIES(rsi->frame_filters_on, rsi->frame_filters_initialized));
+  if (rsi->frame_filters_on && rsi->frame_filters_initialized) {
+    assert(!is_uv);
+    for (int c_id = 0; c_id < num_classes; ++c_id) {
+      copy_nsfilter_taps_for_class(wienerns_info, &rsi->frame_filters, c_id);
+    }
+    return;
   }
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
   for (int c_id = 0; c_id < num_classes; ++c_id) {
-#if CONFIG_COMBINE_PC_NS_WIENER
-    if (skip_filter_read_all_classes) {
-      skip_filter_read_for_class[c_id] = 1;
-      ref_for_class[c_id] = 0;  // last filter in bank.
-      continue;
-    }
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
     const int exact_match = aom_read_symbol(rb, xd->tile_ctx->merged_param_cdf,
                                             2, ACCT_INFO("exact_match"));
     int ref;
@@ -3779,14 +3836,6 @@
       get_wienerns_parameters(xd->current_base_qindex, is_uv);
   const int(*wienerns_coeffs)[WIENERNS_COEFCFG_LEN] = nsfilter_params->coeffs;
   for (int c_id = 0; c_id < num_classes; ++c_id) {
-#if CONFIG_COMBINE_PC_NS_WIENER
-    if (wienerns_info->frame_filters_on &&
-        !bank->frame_filter_predictors_are_set) {
-      fill_first_slot_of_bank_with_filter_match(
-          bank, wienerns_info, wienerns_info->match_indices, base_qindex, c_id,
-          frame_filter_dictionary, dict_stride);
-    }
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
     if (skip_filter_read_for_class[c_id]) {
       copy_nsfilter_taps_for_class(
           wienerns_info,
@@ -3852,50 +3901,24 @@
     }
     av1_add_to_wienerns_bank(bank, wienerns_info, c_id);
   }
-#if CONFIG_COMBINE_PC_NS_WIENER
-  if (wienerns_info->frame_filters_on && !bank->frame_filter_predictors_are_set)
-    bank->frame_filter_predictors_are_set = 1;
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
 }
 #endif  // CONFIG_LR_IMPROVEMENTS
 
 static AOM_INLINE void loop_restoration_read_sb_coeffs(
-    const AV1_COMMON *const cm, MACROBLOCKD *xd, aom_reader *const r, int plane,
+    AV1_COMMON *cm, MACROBLOCKD *xd, aom_reader *const r, int plane,
     int runit_idx
 #if CONFIG_COMBINE_PC_NS_WIENER
     ,
     int16_t *frame_filter_dictionary, int dict_stride
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
 ) {
-  const RestorationInfo *rsi = &cm->rst_info[plane];
+  RestorationInfo *rsi = &cm->rst_info[plane];
   RestorationUnitInfo *rui = &rsi->unit_info[runit_idx];
   assert(rsi->frame_restoration_type != RESTORE_NONE);
 
   assert(!cm->features.all_lossless);
 
   const int wiener_win = (plane > 0) ? WIENER_WIN_CHROMA : WIENER_WIN;
-#if CONFIG_COMBINE_PC_NS_WIENER
-  if (plane == AOM_PLANE_Y && rsi->frame_filters_on &&
-      rsi->frame_filters_initialized) {
-    WienerNonsepInfoBank *bank = &xd->wienerns_info[plane];
-    if (!bank->frame_filter_predictors_are_set) {
-      bank->filter[0].num_classes = rsi->frame_filters.num_classes;
-      for (int c_id = 0; c_id < rsi->frame_filters.num_classes; ++c_id)
-        assert(bank->bank_size_for_class[c_id] == 0);
-      av1_add_to_wienerns_bank(bank, &rsi->frame_filters, ALL_WIENERNS_CLASSES);
-      bank->frame_filter_predictors_are_set = 1;
-
-#if CONFIG_TEMP_LR
-      // TODO: Is this needed?
-      av1_copy_rst_frame_filters(&cm->cur_frame->rst_info[plane], rsi);
-#endif  // CONFIG_TEMP_LR
-    }
-  }
-  rui->wienerns_info.frame_filters_on = rsi->frame_filters_on;
-#if CONFIG_TEMP_LR
-  rui->wienerns_info.temporal_pred_flag = rsi->temporal_pred_flag;
-#endif  // CONFIG_TEMP_LR
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
 #if CONFIG_LR_IMPROVEMENTS
   rui->wienerns_info.num_classes = rsi->num_filter_classes;
 #endif  // CONFIG_LR_IMPROVEMENTS
@@ -3929,14 +3952,13 @@
         break;
 #if CONFIG_LR_IMPROVEMENTS
       case RESTORE_WIENER_NONSEP:
-        read_wienerns_filter(xd, plane != AOM_PLANE_Y, &rui->wienerns_info,
-                             &xd->wienerns_info[plane], r
 #if CONFIG_COMBINE_PC_NS_WIENER
-                             ,
-                             cm->quant_params.base_qindex,
-                             frame_filter_dictionary, dict_stride
+        if (rsi->frame_filters_on && !rsi->frame_filters_initialized)
+          read_wienerns_framefilters(cm, xd, plane, r, frame_filter_dictionary,
+                                     dict_stride);
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
-        );
+        read_wienerns_filter(xd, plane != AOM_PLANE_Y, rsi, &rui->wienerns_info,
+                             &xd->wienerns_info[plane], r);
         break;
       case RESTORE_PC_WIENER:
         // No side-information for now.
@@ -3966,14 +3988,13 @@
     if (aom_read_symbol(r, xd->tile_ctx->wienerns_restore_cdf, 2,
                         ACCT_INFO("wienerns_restore_cdf"))) {
       rui->restoration_type = RESTORE_WIENER_NONSEP;
-      read_wienerns_filter(xd, plane != AOM_PLANE_Y, &rui->wienerns_info,
-                           &xd->wienerns_info[plane], r
 #if CONFIG_COMBINE_PC_NS_WIENER
-                           ,
-                           cm->quant_params.base_qindex,
-                           frame_filter_dictionary, dict_stride
+      if (rsi->frame_filters_on && !rsi->frame_filters_initialized)
+        read_wienerns_framefilters(cm, xd, plane, r, frame_filter_dictionary,
+                                   dict_stride);
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
-      );
+      read_wienerns_filter(xd, plane != AOM_PLANE_Y, rsi, &rui->wienerns_info,
+                           &xd->wienerns_info[plane], r);
     } else {
       rui->restoration_type = RESTORE_NONE;
     }
@@ -3993,6 +4014,7 @@
           1) == 0);
 #endif  // CONFIG_LR_IMPROVEMENTS
 }
+
 static AOM_INLINE void setup_loopfilter(AV1_COMMON *cm,
                                         struct aom_read_bit_buffer *rb) {
   const int num_planes = av1_num_planes(cm);
diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c
index c2b7935..81dd18b 100644
--- a/av1/encoder/bitstream.c
+++ b/av1/encoder/bitstream.c
@@ -79,7 +79,7 @@
 }
 
 static AOM_INLINE void loop_restoration_write_sb_coeffs(
-    const AV1_COMMON *const cm, MACROBLOCKD *xd, const RestorationUnitInfo *rui,
+    AV1_COMMON *cm, MACROBLOCKD *xd, const RestorationUnitInfo *rui,
     aom_writer *const w, int plane, FRAME_COUNTS *counts
 #if CONFIG_COMBINE_PC_NS_WIENER
     ,
@@ -3813,7 +3813,7 @@
     const PARTITION_TREE *ptree_luma,
 #endif  // CONFIG_EXT_RECUR_PARTITIONS
     int mi_row, int mi_col, BLOCK_SIZE bsize) {
-  const AV1_COMMON *const cm = &cpi->common;
+  AV1_COMMON *cm = &cpi->common;
   const CommonModeInfoParams *const mi_params = &cm->mi_params;
   MACROBLOCKD *const xd = &cpi->td.mb.e_mbd;
   assert(bsize < BLOCK_SIZES_ALL);
@@ -3863,32 +3863,6 @@
                                            frame_filter_dictionary, dict_stride
 #endif
           );
-#if CONFIG_COMBINE_PC_NS_WIENER
-          if (plane == AOM_PLANE_Y) {
-            // TODO: Needs to be fixed.
-            RestorationInfo *rsi = (RestorationInfo *)cm->rst_info + plane;
-            if (rsi->frame_filters_on && !rsi->frame_filters_initialized &&
-                rui->restoration_type == RESTORE_WIENER_NONSEP
-#if CONFIG_TEMP_LR
-                && !rsi->temporal_pred_flag
-#endif  // CONFIG_TEMP_LR
-            ) {
-              rsi->frame_filters_initialized = 1;
-              const WienerNonsepInfoBank *bank = &xd->wienerns_info[plane];
-              assert(bank->frame_filter_predictors_are_set);
-              rsi->frame_filters.num_classes = bank->filter[0].num_classes;
-              for (int c_id = 0; c_id < rsi->frame_filters.num_classes;
-                   ++c_id) {
-                copy_nsfilter_taps_for_class(
-                    &rsi->frame_filters,
-                    av1_constref_from_wienerns_bank(bank, 0, c_id), c_id);
-              }
-#if CONFIG_TEMP_LR
-              av1_copy_rst_frame_filters(&cm->cur_frame->rst_info[plane], rsi);
-#endif  // CONFIG_TEMP_LR
-            }
-          }
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
         }
       }
     }
@@ -4332,6 +4306,9 @@
             write_num_classes && NUM_WIENERNS_CLASS_INIT_LUMA > 1;
         if (write_num_classes) {
           aom_wb_write_literal(wb, rsi->frame_filters_on, 1);
+          // printf("Frame %d: frame_filters_on %d temporal_pred_flag %d\n",
+          //        cm->current_frame.order_hint, rsi->frame_filters_on,
+          //        rsi->temporal_pred_flag);
 #if CONFIG_TEMP_LR
           if (rsi->frame_filters_on) {
             const int num_ref_frames = cm->current_frame.frame_type == KEY_FRAME
@@ -4614,6 +4591,18 @@
       ref == AOMMAX(0, bank->bank_size_for_class[wiener_class_id] - 1)));
   return exact_match;
 }
+
+static int check_and_write_exact_match(
+    const WienerNonsepInfo *wienerns_info,
+    const WienerNonsepInfo *ref_wienerns_info,
+    const WienernsFilterParameters *nsfilter_params, int wiener_class_id,
+    MACROBLOCKD *xd, aom_writer *wb) {
+  const int exact_match =
+      check_wienerns_eq(wienerns_info, ref_wienerns_info,
+                        nsfilter_params->ncoeffs, wiener_class_id);
+  aom_write_symbol(wb, exact_match, xd->tile_ctx->merged_param_cdf, 2);
+  return exact_match;
+}
 #endif  // CONFIG_LR_MERGE_COEFFS
 
 #if CONFIG_COMBINE_PC_NS_WIENER
@@ -4633,16 +4622,117 @@
   (void)total_bits;
   assert(total_bits == count_bits);
 }
+
+static AOM_INLINE void write_wienerns_framefilters(
+    AV1_COMMON *cm, MACROBLOCKD *xd, int plane, aom_writer *wb,
+    int16_t *frame_filter_dictionary, int dict_stride) {
+  const int base_qindex = cm->quant_params.base_qindex;
+  RestorationInfo *rsi = &cm->rst_info[plane];
+  const int is_uv = plane > 0;
+  const int num_classes = rsi->num_filter_classes;
+  assert(rsi->frame_filters_on && !rsi->frame_filters_initialized);
+  assert(!is_uv);
+  const WienernsFilterParameters *nsfilter_params =
+      get_wienerns_parameters(base_qindex, plane != AOM_PLANE_Y);
+  int skip_filter_write_for_class[WIENERNS_MAX_CLASSES] = { 0 };
+#if CONFIG_LR_MERGE_COEFFS
+#if CONFIG_TEMP_LR
+  assert(!rsi->temporal_pred_flag);
+#endif  // CONFIG_TEMP_LR
+  write_match_indices(&rsi->frame_filters, wb);
+  WienerNonsepInfoBank bank = { 0 };
+  // needed to handle asserts in copy_nsfilter_taps_for_class
+  bank.filter[0].num_classes = num_classes;
+
+  fill_first_slot_of_bank_with_filter_match(
+      &bank, &rsi->frame_filters, rsi->frame_filters.match_indices, base_qindex,
+      ALL_WIENERNS_CLASSES, frame_filter_dictionary, dict_stride);
+  for (int c_id = 0; c_id < num_classes; ++c_id) {
+    skip_filter_write_for_class[c_id] = check_and_write_exact_match(
+        &rsi->frame_filters, av1_constref_from_wienerns_bank(&bank, 0, c_id),
+        nsfilter_params, c_id, xd, wb);
+  }
+#else
+  (void)xd;
+#endif  // CONFIG_LR_MERGE_COEFFS
+  assert(num_classes <= WIENERNS_MAX_CLASSES);
+  const int(*wienerns_coeffs)[WIENERNS_COEFCFG_LEN] = nsfilter_params->coeffs;
+
+  for (int c_id = 0; c_id < num_classes; ++c_id) {
+    if (skip_filter_write_for_class[c_id]) continue;
+    const WienerNonsepInfo *ref_wienerns_info =
+        av1_constref_from_wienerns_bank(&bank, 0, c_id);
+    const int16_t *wienerns_info_nsfilter =
+        const_nsfilter_taps(&rsi->frame_filters, c_id);
+    const int16_t *ref_wienerns_info_nsfilter =
+        const_nsfilter_taps(ref_wienerns_info, c_id);
+
+    const int beg_feat = 0;
+    int end_feat = nsfilter_params->ncoeffs;
+    if (end_feat > 6) {
+      // Decide whether to signal a short (0) or long (1) filter
+      int filter_length_bit = 0;
+      for (int i = 6; i < end_feat; i++) {
+        if (wienerns_info_nsfilter[i] != 0) {
+          filter_length_bit = 1;
+          break;
+        }
+      }
+      aom_write_symbol(wb, filter_length_bit,
+                       xd->tile_ctx->wienerns_length_cdf[is_uv], 2);
+      end_feat = filter_length_bit ? nsfilter_params->ncoeffs : 6;
+    }
+    assert((end_feat & 1) == 0);
+
+    int uv_sym = 0;
+    if (is_uv && end_feat > 6) {
+      uv_sym = 1;
+      for (int i = 6; i < end_feat; i += 2) {
+        if (wienerns_info_nsfilter[i + 1] != wienerns_info_nsfilter[i])
+          uv_sym = 0;
+      }
+      aom_write_symbol(wb, uv_sym, xd->tile_ctx->wienerns_uv_sym_cdf, 2);
+    }
+
+    for (int i = beg_feat; i < end_feat; ++i) {
+#if ENABLE_LR_4PART_CODE
+      aom_write_4part_wref(
+          wb,
+          ref_wienerns_info_nsfilter[i] -
+              wienerns_coeffs[i - beg_feat][WIENERNS_MIN_ID],
+          wienerns_info_nsfilter[i] -
+              wienerns_coeffs[i - beg_feat][WIENERNS_MIN_ID],
+          xd->tile_ctx->wienerns_4part_cdf[wienerns_coeffs[i - beg_feat]
+                                                          [WIENERNS_PAR_ID]],
+          wienerns_coeffs[i - beg_feat][WIENERNS_BIT_ID]);
+#else
+      aom_write_primitive_refsubexpfin(
+          wb, (1 << wienerns_coeffs[i - beg_feat][WIENERNS_BIT_ID]),
+          wienerns_coeffs[i - beg_feat][WIENERNS_PAR_ID],
+          ref_wienerns_info_nsfilter[i] -
+              wienerns_coeffs[i - beg_feat][WIENERNS_MIN_ID],
+          wienerns_info_nsfilter[i] -
+              wienerns_coeffs[i - beg_feat][WIENERNS_MIN_ID]);
+#endif  // ENABLE_LR_4PART_CODE
+      if (uv_sym && i >= 6) {
+        // Don't code symmetrical taps
+        assert(wienerns_info_nsfilter[i + 1] == wienerns_info_nsfilter[i]);
+        i += 1;
+      }
+    }
+  }
+  rsi->frame_filters_initialized = 1;
+#if CONFIG_TEMP_LR
+  av1_copy_rst_frame_filters(&cm->cur_frame->rst_info[plane], rsi);
+#endif  // CONFIG_TEMP_LR
+  return;
+}
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
 
 static AOM_INLINE void write_wienerns_filter(
-    MACROBLOCKD *xd, int plane, const WienerNonsepInfo *wienerns_info,
-    WienerNonsepInfoBank *bank, aom_writer *wb
-#if CONFIG_COMBINE_PC_NS_WIENER
-    ,
-    int base_qindex, int16_t *frame_filter_dictionary, int dict_stride
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
-) {
+    MACROBLOCKD *xd, int plane, const RestorationInfo *rsi,
+    const WienerNonsepInfo *wienerns_info, WienerNonsepInfoBank *bank,
+    aom_writer *wb) {
   const int is_uv = plane > 0;
   const WienernsFilterParameters *nsfilter_params =
       get_wienerns_parameters(xd->current_base_qindex, plane != AOM_PLANE_Y);
@@ -4651,33 +4741,11 @@
 #if CONFIG_LR_MERGE_COEFFS
 
 #if CONFIG_COMBINE_PC_NS_WIENER
-  const int skip_filter_write_all_classes =
-#if CONFIG_TEMP_LR
-      wienerns_info->temporal_pred_flag ||
-#endif  // CONFIG_TEMP_LR
-      (wienerns_info->frame_filters_on &&
-       bank->frame_filter_predictors_are_set);
-  if (
-#if CONFIG_TEMP_LR
-      !wienerns_info->temporal_pred_flag &&
-#endif  // CONFIG_TEMP_LR
-      wienerns_info->frame_filters_on &&
-      !bank->frame_filter_predictors_are_set) {
-    write_match_indices(wienerns_info, wb);
-    fill_first_slot_of_bank_with_filter_match(
-        bank, wienerns_info, wienerns_info->match_indices, base_qindex,
-        ALL_WIENERNS_CLASSES, frame_filter_dictionary, dict_stride);
-    bank->frame_filter_predictors_are_set = 1;
-  }
+  assert(IMPLIES(rsi->frame_filters_on, rsi->frame_filters_initialized));
+  if (rsi->frame_filters_on && rsi->frame_filters_initialized) return;
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
 
   for (int c_id = 0; c_id < wienerns_info->num_classes; ++c_id) {
-#if CONFIG_COMBINE_PC_NS_WIENER
-    if (skip_filter_write_all_classes) {
-      skip_filter_write_for_class[c_id] = 1;
-      continue;
-    }
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
     skip_filter_write_for_class[c_id] = check_and_write_merge_info(
         wienerns_info, bank, nsfilter_params, c_id, ref_for_class, xd, wb);
   }
@@ -4759,14 +4827,14 @@
 #endif  // CONFIG_LR_IMPROVEMENTS
 
 static AOM_INLINE void loop_restoration_write_sb_coeffs(
-    const AV1_COMMON *const cm, MACROBLOCKD *xd, const RestorationUnitInfo *rui,
+    AV1_COMMON *cm, MACROBLOCKD *xd, const RestorationUnitInfo *rui,
     aom_writer *const w, int plane, FRAME_COUNTS *counts
 #if CONFIG_COMBINE_PC_NS_WIENER
     ,
     int16_t *frame_filter_dictionary, int dict_stride
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
 ) {
-  const RestorationInfo *rsi = cm->rst_info + plane;
+  RestorationInfo *rsi = cm->rst_info + plane;
   RestorationType frame_rtype = rsi->frame_restoration_type;
   assert(frame_rtype != RESTORE_NONE);
 
@@ -4774,16 +4842,6 @@
   assert(!cm->features.all_lossless);
 
   const int wiener_win = (plane > 0) ? WIENER_WIN_CHROMA : WIENER_WIN;
-#if CONFIG_COMBINE_PC_NS_WIENER
-  if (plane == AOM_PLANE_Y && rsi->frame_filters_initialized) {
-    WienerNonsepInfoBank *bank = &xd->wienerns_info[plane];
-    bank->filter[0].num_classes = rsi->frame_filters.num_classes;
-    if (!bank->frame_filter_predictors_are_set) {
-      av1_add_to_wienerns_bank(bank, &rsi->frame_filters, ALL_WIENERNS_CLASSES);
-      bank->frame_filter_predictors_are_set = 1;
-    }
-  }
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
 
   RestorationType unit_rtype = rui->restoration_type;
 #if CONFIG_LR_IMPROVEMENTS
@@ -4821,13 +4879,13 @@
         break;
 #if CONFIG_LR_IMPROVEMENTS
       case RESTORE_WIENER_NONSEP:
-        write_wienerns_filter(
-            xd, plane, &rui->wienerns_info, &xd->wienerns_info[plane], w
 #if CONFIG_COMBINE_PC_NS_WIENER
-            ,
-            cm->quant_params.base_qindex, frame_filter_dictionary, dict_stride
+        if (rsi->frame_filters_on && !rsi->frame_filters_initialized)
+          write_wienerns_framefilters(cm, xd, plane, w, frame_filter_dictionary,
+                                      dict_stride);
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
-        );
+        write_wienerns_filter(xd, plane, rsi, &rui->wienerns_info,
+                              &xd->wienerns_info[plane], w);
         break;
       case RESTORE_PC_WIENER:
         // No side-information for now.
@@ -4863,15 +4921,12 @@
 #endif  // CONFIG_ENTROPY_STATS
     if (unit_rtype != RESTORE_NONE) {
 #if CONFIG_COMBINE_PC_NS_WIENER
-      assert(rui->wienerns_info.frame_filters_on == rsi->frame_filters_on);
+      if (rsi->frame_filters_on && !rsi->frame_filters_initialized)
+        write_wienerns_framefilters(cm, xd, plane, w, frame_filter_dictionary,
+                                    dict_stride);
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
-      write_wienerns_filter(
-          xd, plane, &rui->wienerns_info, &xd->wienerns_info[plane], w
-#if CONFIG_COMBINE_PC_NS_WIENER
-          ,
-          cm->quant_params.base_qindex, frame_filter_dictionary, dict_stride
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
-      );
+      write_wienerns_filter(xd, plane, rsi, &rui->wienerns_info,
+                            &xd->wienerns_info[plane], w);
     }
   } else if (frame_rtype == RESTORE_PC_WIENER) {
     aom_write_symbol(w, unit_rtype != RESTORE_NONE,
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index 338d89e..bb69af3 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -1240,7 +1240,6 @@
   rui->luma_stride = rsc->luma_stride;
   rui->plane = rsc->plane;
   rui->wienerns_info.num_classes = rsc->num_filter_classes;
-  rui->wienerns_info.frame_filters_on = rsc->frame_filters_on;
 #if CONFIG_COMBINE_PC_NS_WIENER
   rui->skip_pcwiener_filtering = 0;
 #endif  // CONFIG_COMBINE_PC_NS_WIENER
@@ -3601,10 +3600,8 @@
          rest_unit_idx_in_rutile);
   assert(unit_stats->plane == rsc->plane);
   assert(rusi->sse[RESTORE_NONE] == unit_stats->real_sse);
-  rusi->wienerns_info.frame_filters_on = 0;
 #if CONFIG_COMBINE_PC_NS_WIENER
   if (rsc->frame_filters_on && rsc->plane == AOM_PLANE_Y) {
-    rui.wienerns_info.frame_filters_on = 1;
     // Pick the best filter for this RU.
     rusi->sse[RESTORE_WIENER_NONSEP] = evaluate_frame_filter(rsc, limits, &rui);
 
@@ -4433,17 +4430,11 @@
   RestSearchCtxt *rsc = (RestSearchCtxt *)priv;
   const RestUnitSearchInfo *rusi = &rsc->rusi[rest_unit_idx];
   const RestorationInfo *rsi = &rsc->cm->rst_info[rsc->plane];
-#if CONFIG_COMBINE_PC_NS_WIENER
-  rsi->unit_info[rest_unit_idx].wienerns_info.frame_filters_on =
-      rsi->frame_filters_on;
-#endif  // CONFIG_COMBINE_PC_NS_WIENER
 
   copy_unit_info(rsi->frame_restoration_type, rusi,
                  &rsi->unit_info[rest_unit_idx], rsc);
 #if CONFIG_TEMP_LR
   assert(rsi->temporal_pred_flag == rsc->temporal_pred_flag);
-  rsi->unit_info[rest_unit_idx].wienerns_info.temporal_pred_flag =
-      rsi->temporal_pred_flag;
 #endif  // CONFIG_TEMP_LR
 }