# Copyright (c) 2019, Alliance for Open Media. All rights reserved
#
# This source code is subject to the terms of the BSD 2 Clause License and
# the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
# was not distributed with this source code in the LICENSE file, you can
# obtain it at www.aomedia.org/license/software. If the Alliance for Open
# Media Patent License 1.0 was not distributed with this source code in the
# PATENTS file, you can obtain it at www.aomedia.org/license/patent.
#

r"""Python module that transforms a Tensorflow ckpt to a C header file.

Usage: tf_ckpt_to_c_header.py [-h] [--input_path INPUT_PATH]
                              [--output_path OUTPUT_PATH]
                              [--header_guard HEADER_GUARD]
                              [--config_name CONFIG_NAME]
                              [--var_regex VAR_REGEX]
                              [--trained_qp TRAINED_QP]
                              [--is_residue IS_RESIDUE]
                              [--ext_width EXT_WIDTH]
                              [--ext_height EXT_HEIGHT]
                              [--strict_bounds STRICT_BOUNDS]
                              [--enable_explicit_field_names]
                              [--enable_aligned_declaration]
                              [--architecture {VDSR|WDSR}]

Optional Arguments:
  -h, --help            show this help message and exit
  --input_path INPUT_PATH
                        Path to ckpt. Please include the full prefix of all
                        relevant ckpt files.
  --output_path OUTPUT_PATH
                        Path to the output header file.
  --header_guard HEADER_GUARD
                        Name of the header file header guard.
  --config_name CONFIG_NAME
                        Name of model and prefix to all relevant weight, bias,
                        and qp variables.
  --var_regex VAR_REGEX
                        Regex to match tensor names against in the model ckpt.
  --trained_qp TRAINED_QP
                        The primary QP of the model.
  --is_residue IS_RESIDUE
                        Whether we are predicting an image or the residue of
                        an image.
  --ext_width EXT_WIDTH
                        Width of the frame extension.
  --ext_height EXT_HEIGHT
                        Height of the frame extension.
  --strict_bounds STRICT_BOUNDS
                        Whether the input bounds are strict or not.
  --enable_explicit_field_names
                        Whether to print field names alongside values in the
                        CNN config.
  --enable_aligned_declaration
                        Whether to declare all the weights and biases with
                        DECLARE_ALIGNED(32, ...) so that they are aligned in
                        memory.
  --architecture {VDSR|WDSR}
                        Kind of architecture the network uses.

Example Invocation:
$ python tf_ckpt_to_c_header.py --input_path="./model.ckpt" \
    --output_path="./output.h" --header_guard="MODEL_H_" --trained_qp=32 \
    --enable_explicit_field_names
"""

import argparse
import collections
import copy
import re
import sys

from tf_ckpt_to_c_header_struct_parser import StructNode

from tensorflow import logging
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app

# Parsed command-line flags. Rebound to the argparse namespace in the
# __main__ block; stays None when this file is only imported.
FLAGS = None
# Templates for the structs written into the generated header. Each entry is a
# (field name, default value) pair; _build_layer_config deep-copies these
# templates and overrides individual fields. The field order must match the
# declaration order of the corresponding struct in av1/common/cnn.h, because
# positional C initializers are matched to fields by declaration order.
BRANCH_CONFIG_ORDER = StructNode((
    ("input_to_branches", "0x00"),
    ("copy_to_branches", 0),
    ("branches_to_combine", "0x00")), name="branch_config")
BATCHNORM_PARAMS_ORDER = StructNode((
    ("bn_gamma", None),
    ("bn_beta", None),
    ("bn_mean", None),
    ("bn_std", None)), name="batchnorm_params")
LAYER_CONFIG_ORDER = StructNode((
    ("in_channels", -1),
    ("filter_width", -1),
    ("filter_height", -1),
    ("out_channels", -1),
    ("skip_width", 1),
    ("skip_height", 1),
    ("maxpool", 0),
    ("weights", ""),
    ("bias", ""),
    ("pad", "PADDING_SAME_ZERO"),
    ("activation", "NONE"),
    ("deconvolve", 0),
    ("branch", 0),
    ("branch_copy_type", "BRANCH_NO_COPY"),
    ("branch_combine_type", "BRANCH_NOC"),
    ("branch_config", BRANCH_CONFIG_ORDER),
    ("bn_params", BATCHNORM_PARAMS_ORDER),
    ("output_num", -1)))
# Placeholder only: the __main__ block rebinds this to a StructNode whose
# defaults come from the parsed flags.
CNN_CONFIG_ORDER = {}
# Alignment passed as the first argument of DECLARE_ALIGNED when
# --enable_aligned_declaration is set.
BIT_ALIGNMENT = 32
# Suffixes appended to a layer's tensor-name prefix to obtain the names of its
# weight and bias tensors.
WEIGHT_STRING = "w"
BIAS_STRING = "b"

# Make the INFO-level progress messages logged by this script visible.
logging.set_verbosity("INFO")


def _get_weight_tensor_from_prefix(tensor_prefix):
  """Returns the weight-tensor name for the layer prefix `tensor_prefix`."""
  return "".join((tensor_prefix, WEIGHT_STRING))


def _get_bias_tensor_from_prefix(tensor_prefix):
  """Returns the bias-tensor name for the layer prefix `tensor_prefix`."""
  return "".join((tensor_prefix, BIAS_STRING))


def _generate_layer_index_tensor_name_map(input_reader, var_regex):
  """Create a map from tuples of layer indices to tensor-name prefixes.

  Every checkpoint variable whose name matches `var_regex` contributes one
  entry. The key is built from regex group 1, which is expected to be a run of
  "<digits>_" components (e.g. "1_2_" -> (1, 2)). The value is the variable
  name with its final character stripped. With the default regex that final
  character is the "w"/"b" decorator, so a layer's weight and bias collapse
  into one entry keyed by the same index tuple.

  Args:
    input_reader: Checkpoint reader exposing get_variable_to_shape_map().
    var_regex: Regex whose group 1 captures the "<index>_" components.

  Returns:
    Dict mapping index tuples to tensor-name prefixes.
  """
  pattern = re.compile(var_regex)
  layer_index_tensor_name_map = {}
  # Iterate the dict directly (instead of the Python 2-only .iteritems()) so
  # the script runs under both Python 2 and Python 3.
  for k in input_reader.get_variable_to_shape_map():
    match = pattern.match(k)
    if match:
      var_indices = tuple(int(x) for x in match.group(1).split("_")[:-1])
      # Ignore the (b|w) decorator since we cannot control the order they
      # are read; both tensors of a layer map to the same prefix.
      layer_index_tensor_name_map[var_indices] = k[:-1]
  logging.info("Gathered %d variables: %s", len(layer_index_tensor_name_map),
               list(layer_index_tensor_name_map.values()))
  return layer_index_tensor_name_map


def _format_tensor(tensor, layer):
  """Reformats a tensor from a Python-style array to a C array initializer.

  The tensor is flattened in row-major order and each element is written as a
  "%f"-formatted float literal with an "f" suffix, e.g.
  [[1, 2]] -> "{1.000000f, 2.000000f,}". A warning is logged when every
  element is zero.

  Args:
    tensor: Array object with flatten() and any() methods (e.g. the
      numpy.ndarray returned by the checkpoint reader).
    layer: Layer index, used only in the warning message.

  Returns:
    The C initializer string, including the surrounding braces.
  """
  flattened_tensor = tensor.flatten()
  # Look for any non-zero element rather than a zero sum: values can cancel
  # (e.g. [1, -1]) and a zero sum would falsely report a zero tensor.
  if not flattened_tensor.any():
    logging.warning("Tensor at layer %d is a zero tensor!", layer)
  # NOTE(review): "%f" keeps six decimal places, so magnitudes below 5e-7 are
  # written as 0 and small weights lose precision; consider a round-trip
  # format if exact weights matter.
  parsed_tensor = " ".join("%ff," % value for value in flattened_tensor)
  return "{" + parsed_tensor + "}"


def _build_layer_config(shape, layer, num_layers):
  """Builds the layer-config StructNode for one convolutional layer.

  Args:
    shape: Shape of the layer's weight tensor. TensorFlow conv2d filters use
      the layout [filter_height, filter_width, in_channels, out_channels].
    layer: Zero-based index of the layer within the network.
    num_layers: Total number of layers in the network.

  Returns:
    A deep copy of LAYER_CONFIG_ORDER with the array names, filter geometry,
    output index, and any architecture-specific activation/branch fields
    filled in.
  """
  layer_config = copy.deepcopy(LAYER_CONFIG_ORDER)
  layer_config["weights"] = "%s_weight_%s" % (FLAGS.config_name, layer)
  layer_config["bias"] = "%s_bias_%s" % (FLAGS.config_name, layer)
  # Dimension 0 of a TensorFlow conv2d filter is its height and dimension 1
  # its width; swapping them would transpose non-square filters.
  layer_config["filter_width"] = shape[1]
  layer_config["filter_height"] = shape[0]
  layer_config["in_channels"] = shape[2]
  layer_config["out_channels"] = shape[3]

  if layer == num_layers - 1:
    # Other layers get default which is -1.
    layer_config["output_num"] = 0

  if FLAGS.architecture == "WDSR" and 0 < layer < num_layers - 1:
    branch_config = copy.deepcopy(BRANCH_CONFIG_ORDER)
    # Hidden layers are grouped by layer % 3 into residual blocks.
    if layer % 3 == 1:
      # Input residual block layer.
      layer_config["activation"] = "RELU"
      layer_config["branch_copy_type"] = "BRANCH_INPUT"
      branch_config["input_to_branches"] = "0x02"
    elif layer % 3 == 0:
      # Output residual block layer.
      layer_config["branch_combine_type"] = "BRANCH_ADD"
      branch_config["branches_to_combine"] = "0x02"
    layer_config["branch_config"] = branch_config
  elif FLAGS.architecture == "VDSR" and layer < num_layers - 1:
    # Every layer except the last uses a ReLU activation.
    layer_config["activation"] = "RELU"
  return layer_config


def _extract_layer_variables(input_reader):
  """Extracts the layer configs, weights, and biases from a checkpoint.

  Layers are ordered by their index tuple (as produced by
  _generate_layer_index_tensor_name_map), so layer i of the output is the
  layer with the i-th smallest index tuple.

  Args:
    input_reader: Checkpoint reader exposing get_variable_to_shape_map() and
      get_tensor().

  Returns:
    A (cnn_config, weights, biases) tuple. cnn_config is a copy of
    CNN_CONFIG_ORDER with "num_layers" and "layer_config" filled in; weights
    and biases are lists of C initializer strings, one per layer.
  """
  # Copy the template (as _build_layer_config does for LAYER_CONFIG_ORDER) so
  # the module-level CNN_CONFIG_ORDER is never mutated.
  cnn_config = copy.deepcopy(CNN_CONFIG_ORDER)
  layer_index_tensor_name_map = _generate_layer_index_tensor_name_map(
      input_reader, FLAGS.var_regex)
  num_layers = len(layer_index_tensor_name_map)
  cnn_config["num_layers"] = num_layers

  # Fetch the shape map once rather than once per layer.
  shape_map = input_reader.get_variable_to_shape_map()
  layer_configs = []
  weights = []
  biases = []
  # By the end of this loop, every layer has a config, weights, and biases.
  for layer, (_, tensor_name) in enumerate(
      sorted(layer_index_tensor_name_map.items())):
    weight_name = _get_weight_tensor_from_prefix(tensor_name)
    bias_name = _get_bias_tensor_from_prefix(tensor_name)
    layer_configs.append(
        _build_layer_config(shape_map[weight_name], layer, num_layers))
    weights.append(_format_tensor(input_reader.get_tensor(weight_name), layer))
    biases.append(_format_tensor(input_reader.get_tensor(bias_name), layer))
  cnn_config["layer_config"] = layer_configs
  return cnn_config, weights, biases


def _print_header(header_guard, output_file):
  """Writes the opening of the generated header.

  The opening consists of the AOM license banner, the #ifndef/#define include
  guard named `header_guard`, and the #include lines needed by the generated
  declarations.

  Args:
    header_guard: Macro name used for the include guard.
    output_file: Writable text file object.
  """
  license_banner = """/*
 * Copyright (c) 2019, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
"""
  output_file.write(license_banner +
                    "\n#ifndef %s\n#define %s\n" % (header_guard, header_guard))
  includes = []
  if FLAGS.enable_aligned_declaration:
    # Presumably provides DECLARE_ALIGNED, which _print_weights_biases emits
    # under the same flag.
    includes.append("\n#include \"aom_ports/mem.h\"")
  includes.append("\n#include \"av1/common/cnn.h\"\n")
  output_file.write("".join(includes))


def _print_weights_biases(layers, weights, biases, output_file):
  """Writes the model's trained QP and each layer's weight and bias arrays.

  Every emitted name is prefixed with FLAGS.config_name so that several model
  headers can coexist. The output consists of:
    - `static const int <config_name>_trained_qp`, then
    - for each layer in order, `<config_name>_weight_<layer>[]` followed by
      `<config_name>_bias_<layer>[]`, declared as `static float` arrays and
      wrapped in DECLARE_ALIGNED(BIT_ALIGNMENT, ...) when
      FLAGS.enable_aligned_declaration is set.

  The CNN_CONFIG struct that references these arrays is written separately by
  _print_cnn_config.

  Args:
    layers: Number of layers; arrays for layers 0 .. layers - 1 are written.
    weights: Per-layer C initializer strings for the weights, in layer order.
    biases: Per-layer C initializer strings for the biases, in layer order.
    output_file: Writable text file object that receives the declarations.
  """
  output_file.write("\nstatic const int %s_trained_qp = %d;\n" %
                    (FLAGS.config_name, FLAGS.trained_qp))
  aligned = FLAGS.enable_aligned_declaration
  for layer in range(layers):
    for kind, initializer in (("weight", weights[layer]),
                              ("bias", biases[layer])):
      array_name = "%s_%s_%d[]" % (FLAGS.config_name, kind, layer)
      if aligned:
        declaration = "DECLARE_ALIGNED(%d, static float, %s)" % (
            BIT_ALIGNMENT, array_name)
      else:
        declaration = "static float %s" % array_name
      output_file.write("\n%s = %s;\n" % (declaration, initializer))


def _print_cnn_config(cnn_config, output_file):
  """Writes the `const CNN_CONFIG <config_name> = {...};` initializer.

  Args:
    cnn_config: StructNode holding the top-level network config; it renders
      its own fields through write_fields_to_output().
    output_file: Writable text file object.
  """
  # NOTE(review): unlike the arrays and the QP, this declaration is not
  # `static`. In C a file-scope const object has external linkage, so
  # including the header in several translation units may cause duplicate
  # definitions at link time -- confirm this is intended.
  output_file.write("\nconst CNN_CONFIG {} = {{".format(FLAGS.config_name))
  show_field_names = FLAGS.enable_explicit_field_names
  cnn_config.write_fields_to_output(output_file, show_field_names)
  output_file.write("\n};\n")


def _print_footer(header_guard, output_file):
  """Writes the closing #endif of the include guard.

  The output ends with a newline: the C standard requires a non-empty source
  file to end in a new-line character, and some compilers warn otherwise
  (e.g. clang's -Wnewline-eof).

  Args:
    header_guard: Macro name of the include guard, repeated in the comment.
    output_file: Writable text file object.
  """
  output_file.write("\n#endif  // %s\n" % header_guard)


def generate_header_file():
  """Generates a C header file from a Tensorflow model ckpt.

  Reads the checkpoint at FLAGS.input_path and writes to FLAGS.output_path,
  in order: the license banner and include guard, the trained QP and each
  layer's weight and bias arrays, the CNN_CONFIG initializer, and the closing
  #endif. Errors raised while opening the checkpoint propagate unchanged.
  """
  input_reader = pywrap_tensorflow.NewCheckpointReader(FLAGS.input_path)
  cnn_config, weights, biases = _extract_layer_variables(input_reader)

  # Internal consistency checks: every layer needs weights, a bias, and a
  # layer config before anything is written.
  assert all(weight is not None for weight in weights)
  assert all(bias is not None for bias in biases)
  assert all(config is not None for config in cnn_config["layer_config"])

  # The `with` statement closes the file; no explicit close() is needed.
  with open(FLAGS.output_path, "w") as output_file:
    _print_header(FLAGS.header_guard, output_file)
    _print_weights_biases(cnn_config["num_layers"], weights, biases,
                          output_file)
    _print_cnn_config(cnn_config, output_file)
    _print_footer(FLAGS.header_guard, output_file)


def main(_):
  """Entry point for app.run; the positional argv argument is ignored."""
  return generate_header_file()


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--input_path",
      type=str,
      default="",
      help="Path to ckpt. Please include the full prefix of all relevant ckpt "
      "files.")
  parser.add_argument(
      "--output_path",
      type=str,
      default="",
      help="Path to the output header file.")
  parser.add_argument(
      "--header_guard",
      type=str,
      default="AOM_AV1_COMMON_MODEL_H_",
      help="Name of the header file header guard.")
  parser.add_argument(
      "--config_name",
      type=str,
      default="model",
      help="Name of model and prefix to all relevant weight, bias, and qp "
      "variables.")
  parser.add_argument(
      "--var_regex",
      type=str,
      default=r".*conv_(([0-9][0-9]*_)*)(w|b)$",
      help="Regex to match tensor names against in the model ckpt.")
  parser.add_argument(
      "--trained_qp",
      type=int,
      default=32,
      help="The primary QP of the model.")
  parser.add_argument(
      "--is_residue",
      type=int,
      default=1,
      help="Whether we are predicting an image or the residue of an image.")
  parser.add_argument(
      "--ext_width",
      type=int,
      default=0,
      help="Width of the frame extension.")
  parser.add_argument(
      "--ext_height",
      type=int,
      default=0,
      help="Height of the frame extension.")
  parser.add_argument(
      "--strict_bounds",
      type=int,
      default=0,
      help="Whether the input bounds are strict or not.")
  parser.add_argument(
      "--enable_explicit_field_names",
      default=False,
      action="store_true",
      help="Whether to print field names along side values in cnn config.")
  parser.add_argument(
      "--enable_aligned_declaration",
      default=False,
      action="store_true",
      # The help text reflects the actual alignment argument (BIT_ALIGNMENT)
      # emitted in DECLARE_ALIGNED.
      help="Whether to declare weights and biases with DECLARE_ALIGNED(%d, "
           "...) so that they are aligned in memory." % BIT_ALIGNMENT)
  parser.add_argument(
      "--architecture",
      default=None,
      choices=["WDSR", "VDSR"],
      help="Kind of architecture the network uses.")
  # Module-level assignment: rebinds the global FLAGS read by the functions.
  FLAGS, unparsed = parser.parse_known_args()

  # Fail early with a usage message instead of an opaque error from the
  # checkpoint reader or open() on an empty path.
  if not FLAGS.input_path:
    parser.error("--input_path is required.")
  if not FLAGS.output_path:
    parser.error("--output_path is required.")

  if FLAGS.architecture:
    logging.info("Using %s architecture.", FLAGS.architecture)
  else:
    logging.info("Using default architecture. You will need to fill network "
                 "details manually.")

  # The order of these fields must reflect the order of the structs in
  # ${AOM_ROOT}/av1/common/cnn.h.

  CNN_CONFIG_ORDER = StructNode((
      ("num_layers", 0),
      ("is_residue", FLAGS.is_residue),
      ("ext_width", FLAGS.ext_width),
      ("ext_height", FLAGS.ext_height),
      ("strict_bounds", FLAGS.strict_bounds),
      ("layer_config", [])
  ))

  app.run(main=main, argv=[sys.argv[0]] + unparsed)
