|  | """ | 
|  | Copyright (c) 2024, Alliance for Open Media. All rights reserved | 
|  |  | 
|  | This source code is subject to the terms of the BSD 3-Clause Clear License | 
|  | and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear | 
|  | License was not distributed with this source code in the LICENSE file, you | 
|  | can obtain it at aomedia.org/license/software-license/bsd-3-c-c/.  If the | 
|  | Alliance for Open Media Patent License 1.0 was not distributed with this | 
|  | source code in the PATENTS file, you can obtain it at | 
|  | aomedia.org/license/patent-license/. | 
|  | """ | 
|  | import parakit.config.user as user | 
|  | from parakit.entropy.codec_default_cdf import get_aom_cdf_string | 
|  |  | 
# Default value of the ``is_commented`` argument of the
# EntropyContext._model_string_* helpers. When true, each emitted initializer
# is followed by a "   // <key> <num_samples>" trailing comment.
ADD_COMMENTS: bool = False
|  |  | 
|  |  | 
class EntropyContext:
    """Render the CDF initializers of one entropy-context group as C source.

    Each context in the group is addressed by an index list (one index per
    dimension). Its model is looked up in ``model_dict`` under the key
    ``ctx_group_name[i0][i1]...`` (see ``_get_key``). A model dict is read for
    its ``"initializer"``, ``"init_rate"`` and ``"num_samples"`` entries.

    Attributes:
        ctx_group_name: Base name of the context group. It is the prefix of
            the lookup keys and of the generated array name.
        num_symb: Symbol count, emitted as ``CDF_SIZE(num_symb)``.
        num_dims: Number of context dimensions rendered by
            ``get_complete_model_string`` (0 to 4).
        size_list: Size of each context dimension.
        model_dict: Mapping from context key to model dict.
        user_cfg: Value forwarded to ``user.read_config_context`` when building
            the AVM variable name.
    """

    def __init__(
        self,
        ctx_group_name,
        num_symb,
        num_dims,
        size_list,
        model_dict,
        user_cfg_file=(),
    ):
        self.ctx_group_name = ctx_group_name
        self.num_symb = num_symb
        self.num_dims = num_dims
        self.size_list = size_list
        self.model_dict = model_dict
        self.user_cfg = user_cfg_file

    def get_model_dictionary(self, index_list):
        """Return the model stored in ``model_dict`` for the context at ``index_list``."""
        return self.model_dict[self._get_key(index_list)]

    def get_model_string(self, index_list):
        """Return ``"<cdf>, <rate>"`` for one context (the body of a ``{ ... }`` row)."""
        return self.get_cdf_string(index_list) + ", " + self.get_rate_string(index_list)

    def get_rate_string(self, index_list):
        """Return the model's ``init_rate`` converted to str and right-justified to width 3."""
        model = self.get_model_dictionary(index_list)
        return str(model["init_rate"]).rjust(3)

    def get_cdf_string(self, index_list):
        """Return ``get_aom_cdf_string`` applied to the model's ``initializer`` minus its last entry."""
        model = self.get_model_dictionary(index_list)
        # The final entry is deliberately not passed on; presumably it is
        # implied by the CDF format (TODO confirm against get_aom_cdf_string).
        return get_aom_cdf_string(model["initializer"][:-1])

    def get_frequency_string(self, index_list):
        """Return the model's ``num_samples`` converted to str and right-justified to width 8."""
        model = self.get_model_dictionary(index_list)
        return str(model["num_samples"]).rjust(8)

    def get_complete_model_string(self, is_avm_style=True):
        """Return the full C declaration and initializer for this context group.

        Args:
            is_avm_style: If true, the left-hand side is
                ``static const aom_cdf_prob <configured name>[...]``.
                Otherwise it is ``ctx_group_name`` followed by the array
                dimensions.

        Returns:
            The declaration followed by a newline. If ``num_dims`` is not one
            of 0-4, the message "Number of dimensions needs to be between 0
            and 4." followed by a newline is returned instead.
        """
        if is_avm_style:
            lhs = self._get_avm_variable_name()
        else:
            lhs = self._get_arrayname()
        model_str = lhs + " = "

        # Dispatch through the per-dimension methods so that subclasses
        # overriding any of them keep working.
        builders = {
            0: self._model_string_dim0,
            1: self._model_string_dim1,
            2: self._model_string_dim2,
            3: self._model_string_dim3,
            4: self._model_string_dim4,
        }
        builder = builders.get(self.num_dims)
        if builder is None:
            # Reported in-band (not raised) for compatibility with existing callers.
            model_str = "Number of dimensions needs to be between 0 and 4."
        else:
            model_str = builder(model_str)

        # last characters
        return model_str + "\n"

    def _get_num_ctx_groups(self):
        """Return the total number of contexts, i.e. the product of ``size_list``."""
        num_context_groups = 1
        for size in self.size_list:
            num_context_groups *= size
        return num_context_groups

    def _get_key(self, index_list):
        """Return the ``model_dict`` key ``ctx_group_name[i0][i1]...`` for ``index_list``."""
        return self.ctx_group_name + "".join(
            "[" + str(index) + "]" for index in index_list
        )

    def _get_array_dims_suffix(self):
        """Return ``[size0][size1]...[CDF_SIZE(num_symb)]`` for this context group."""
        dims = "".join("[" + str(size) + "]" for size in self.size_list)
        return dims + "[CDF_SIZE(" + str(self.num_symb) + ")]"

    def _get_arrayname(self):
        """Return ``ctx_group_name`` followed by the array dimensions."""
        return self.ctx_group_name + self._get_array_dims_suffix()

    def _get_avm_variable_name(self):
        """Return the ``static const aom_cdf_prob`` declaration (without initializer).

        The variable name comes from ``user.read_config_context`` applied to
        ``user_cfg`` and ``ctx_group_name``.
        """
        variable_name = user.read_config_context(self.user_cfg, self.ctx_group_name)
        return (
            "static const aom_cdf_prob "
            + variable_name
            + self._get_array_dims_suffix()
        )

    def _get_comment_string(self, index_list, is_commented):
        """Return the trailing ``   // <key> <num_samples>`` comment, or "" if disabled."""
        if not is_commented:
            return ""
        return (
            "   // "
            + self._get_key(index_list)
            + " "
            + self.get_frequency_string(index_list)
        )

    def _model_string_dim0(self, model_str, is_commented=ADD_COMMENTS):
        """Append the single-context initializer ``{ <cdf>, <rate> };`` to ``model_str``."""
        index_list = []
        comment_str = self._get_comment_string(index_list, is_commented)
        return (
            model_str
            + "{ "
            + self.get_model_string(index_list)
            + " };"
            + comment_str
        )

    def _model_string_dim1_special_nobracket(
        self, model_str, is_commented=ADD_COMMENTS
    ):
        """Append one unindented ``{ ... },`` row per entry of the first dimension.

        Unlike ``_model_string_dim1``, no enclosing braces are emitted.
        """
        rows = [model_str]
        for i0 in range(self.size_list[0]):
            index_list = [i0]
            comment_str = self._get_comment_string(index_list, is_commented)
            rows.append(
                "{ " + self.get_model_string(index_list) + " }," + comment_str + "\n"
            )
        return "".join(rows)

    def _model_string_nested(self, model_str, num_dims, is_commented=ADD_COMMENTS):
        """Append a ``num_dims``-deep brace-nested initializer to ``model_str``.

        This is the shared implementation of ``_model_string_dim1`` through
        ``_model_string_dim4``; the recursion itself is depth-generic. Each
        nesting level is indented by two more spaces. Inner blocks are closed
        with ``},`` and the outermost block with ``};``. ``num_dims`` is
        expected to be at least 1.
        """
        parts = [model_str, "{\n"]
        parts.extend(self._nested_rows([], num_dims, is_commented))
        parts.append("};")
        return "".join(parts)

    def _nested_rows(self, index_prefix, num_dims, is_commented):
        """Return the lines for all contexts whose leading indices are ``index_prefix``.

        Helper for ``_model_string_nested``; recurses one dimension per call.
        ``size_list[level]`` is only read once that level is reached.
        """
        level = len(index_prefix)
        indent = "  " * (level + 1)
        rows = []
        if level + 1 == num_dims:
            # Innermost dimension: one "{ <cdf>, <rate> }," row per context.
            for index in range(self.size_list[level]):
                index_list = index_prefix + [index]
                # Built before the model string to keep the original lookup order.
                comment_str = self._get_comment_string(index_list, is_commented)
                rows.append(
                    indent
                    + "{ "
                    + self.get_model_string(index_list)
                    + " },"
                    + comment_str
                    + "\n"
                )
        else:
            for index in range(self.size_list[level]):
                rows.append(indent + "{\n")
                rows.extend(
                    self._nested_rows(index_prefix + [index], num_dims, is_commented)
                )
                rows.append(indent + "},\n")
        return rows

    def _model_string_dim1(self, model_str, is_commented=ADD_COMMENTS):
        """Append a 1-D brace-nested initializer; see ``_model_string_nested``."""
        return self._model_string_nested(model_str, 1, is_commented)

    def _model_string_dim2(self, model_str, is_commented=ADD_COMMENTS):
        """Append a 2-D brace-nested initializer; see ``_model_string_nested``."""
        return self._model_string_nested(model_str, 2, is_commented)

    def _model_string_dim3(self, model_str, is_commented=ADD_COMMENTS):
        """Append a 3-D brace-nested initializer; see ``_model_string_nested``."""
        return self._model_string_nested(model_str, 3, is_commented)

    def _model_string_dim4(self, model_str, is_commented=ADD_COMMENTS):
        """Append a 4-D brace-nested initializer; see ``_model_string_nested``."""
        return self._model_string_nested(model_str, 4, is_commented)