/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 3-Clause Clear License
 * and the Alliance for Open Media Patent License 1.0. If the BSD 3-Clause Clear
 * License was not distributed with this source code in the LICENSE file, you
 * can obtain it at aomedia.org/license/software-license/bsd-3-c-c/.  If the
 * Alliance for Open Media Patent License 1.0 was not distributed with this
 * source code in the PATENTS file, you can obtain it at
 * aomedia.org/license/patent-license/.
 */

#include "av1/encoder/part_split_prune_tflite.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <vector>

#include "common/tf_lite_includes.h"

#include "av1/encoder/simple_intrapred_tflite_model_128x128.h"
#include "av1/encoder/simple_intrapred_tflite_model_16x16.h"
#include "av1/encoder/simple_intrapred_tflite_model_32x32.h"
#include "av1/encoder/simple_intrapred_tflite_model_64x64.h"
#include "av1/encoder/sms_part_split_prune_tflite_model.h"

#if HAVE_FEXCEPT
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <fenv.h>
#endif

// Owning handle for an XNNPack delegate: the wrapped TfLiteDelegate is
// released with TfLiteXNNPackDelegateDelete when the handle is destroyed.
using TfLiteDelegateType =
    std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>;

// Per-context cache of lazily created TFLite interpreters, one slot per
// supported MODEL_TYPE, plus the XNNPack delegates those interpreters use.
//
// Member order matters: C++ destroys members in reverse declaration order.
// Every interpreter was modified with a delegate stored in `to_delete`, and a
// TFLite delegate must outlive the interpreter it is applied to. Declaring
// `to_delete` first guarantees it is destroyed last, after all interpreters.
struct PartSplitContext {
  // Delegates owned on behalf of the interpreters below; must be destroyed
  // after them (hence declared first).
  std::vector<TfLiteDelegateType> to_delete;

  // Intra models.
  std::unique_ptr<tflite::Interpreter> model_128X128;
  std::unique_ptr<tflite::Interpreter> model_64X64;
  std::unique_ptr<tflite::Interpreter> model_32X32;
  std::unique_ptr<tflite::Interpreter> model_16X16;

  // Inter models.
  std::unique_ptr<tflite::Interpreter> model_inter_64x64;
  std::unique_ptr<tflite::Interpreter> model_inter_32x32;
  std::unique_ptr<tflite::Interpreter> model_inter_16x16;
  std::unique_ptr<tflite::Interpreter> model_inter_8x8;
};

// Serializes interpreter construction in create_interpreter().
// NOTE(review): this has external linkage; confirm no other translation unit
// refers to it before making it `static`.
std::mutex tfliteMutex{};

static std::unique_ptr<tflite::Interpreter> create_interpreter(
    unsigned char *model_def, std::vector<TfLiteDelegateType> &to_delete) {
  std::lock_guard<std::mutex> lock(tfliteMutex);
  tflite::LoggerOptions::SetMinimumLogSeverity(tflite::TFLITE_LOG_ERROR);
  tflite::Model *model = (tflite::Model *)tflite::GetModel(model_def);

  const int num_threads = 1;
  TfLiteXNNPackDelegateOptions xnnpack_options =
      TfLiteXNNPackDelegateOptionsDefault();
  xnnpack_options.num_threads = AOMMAX(num_threads, 1);
  TfLiteDelegateType xnnpack_delegate(
      TfLiteXNNPackDelegateCreate(&xnnpack_options),
      &TfLiteXNNPackDelegateDelete);

  tflite::MutableOpResolver resolver;
  RegisterSelectedOps(&resolver);

  tflite::InterpreterBuilder builder(model, resolver);
  tflite::ErrorReporter *reporter(tflite::DefaultErrorReporter());
  std::unique_ptr<tflite::Interpreter> interpreter;
  builder(&interpreter);
  if (interpreter->ModifyGraphWithDelegate(xnnpack_delegate.get()) !=
      kTfLiteOk) {
    reporter->Report("Failed at modifying graph with XNNPack delegate");
    exit(1);
  }

  if (interpreter->AllocateTensors() != kTfLiteOk) {
    reporter->Report("Failed at allocating tensors");
    exit(1);
  }

  to_delete.push_back(std::move(xnnpack_delegate));
  return interpreter;
}

static void ensure_tflite_init(void **context, MODEL_TYPE model_type) {
  assert(model_type != MODEL_OTHER);

  if (*context == nullptr) *context = new PartSplitContext();
  PartSplitContext *ctx = (PartSplitContext *)*context;
  switch (model_type) {
    case MODEL_128X128:
      if (!ctx->model_128X128) {
        ctx->model_128X128 = create_interpreter(
            a3_qp96_128_160_luma_BLOCK_128X128_intra_tflite, ctx->to_delete);
      }
      break;
    case MODEL_64X64:
      if (!ctx->model_64X64) {
        ctx->model_64X64 = create_interpreter(
            a3_qp96_128_160_luma_BLOCK_64X64_intra_tflite, ctx->to_delete);
      }
      break;
    case MODEL_32X32:
      if (!ctx->model_32X32) {
        ctx->model_32X32 = create_interpreter(
            a3_qp96_128_160_luma_BLOCK_32X32_intra_tflite, ctx->to_delete);
      }
      break;
    case MODEL_16X16:
      if (!ctx->model_16X16) {
        ctx->model_16X16 = create_interpreter(
            a3_qp96_128_160_luma_BLOCK_16X16_intra_tflite, ctx->to_delete);
      }
      break;
    case MODEL_INTER_64X64:
      if (!ctx->model_inter_64x64) {
        ctx->model_inter_64x64 = create_interpreter(
            sms_part_split_prune_tflite_model_bs12, ctx->to_delete);
      }
      break;
    case MODEL_INTER_32X32:
      if (!ctx->model_inter_32x32) {
        ctx->model_inter_32x32 = create_interpreter(
            sms_part_split_prune_tflite_model_bs9, ctx->to_delete);
      }
      break;
    case MODEL_INTER_16X16:
      if (!ctx->model_inter_16x16) {
        ctx->model_inter_16x16 = create_interpreter(
            sms_part_split_prune_tflite_model_bs6, ctx->to_delete);
      }
      break;
    case MODEL_INTER_8X8:
      if (!ctx->model_inter_8x8) {
        ctx->model_inter_8x8 = create_interpreter(
            sms_part_split_prune_tflite_model_bs3, ctx->to_delete);
      }
      break;
    default: break;
  }
}

// Copies the ModelParams entry for (`model_type`, `prune_level`) into
// `*params`. Returns 0 on success, or -1 (leaving `*params` untouched) when
// `model_type` has no parameter table.
// NOTE(review): `prune_level` is not range-checked here; callers must keep it
// within the bounds of the selected table.
extern "C" int av2_part_split_prune_tflite_params(MODEL_TYPE model_type,
                                                  int prune_level,
                                                  struct ModelParams *params) {
  assert(model_type != MODEL_OTHER);

  // Pick the per-model table first, then index it in a single place.
  const struct ModelParams *table = nullptr;
  switch (model_type) {
    case MODEL_128X128:
      table = a3_qp96_128_160_luma_BLOCK_128X128_intra_tflite_params;
      break;
    case MODEL_64X64:
      table = a3_qp96_128_160_luma_BLOCK_64X64_intra_tflite_params;
      break;
    case MODEL_32X32:
      table = a3_qp96_128_160_luma_BLOCK_32X32_intra_tflite_params;
      break;
    case MODEL_16X16:
      table = a3_qp96_128_160_luma_BLOCK_16X16_intra_tflite_params;
      break;
    case MODEL_INTER_64X64:
      table = sms_part_split_prune_tflite_model_params_bs12;
      break;
    case MODEL_INTER_32X32:
      table = sms_part_split_prune_tflite_model_params_bs9;
      break;
    case MODEL_INTER_16X16:
      table = sms_part_split_prune_tflite_model_params_bs6;
      break;
    case MODEL_INTER_8X8:
      table = sms_part_split_prune_tflite_model_params_bs3;
      break;
    default: return -1;
  }

  *params = table[prune_level];
  return 0;
}

// When HAVE_FEXCEPT && CONFIG_DEBUG, FP traps may be enabled elsewhere; these
// macros mask the underflow/overflow traps around a block of code (used around
// model inference) and then re-enable whatever traps were active before. In
// other builds they expand to nothing.
#if HAVE_FEXCEPT && CONFIG_DEBUG
#define FLOATING_POINT_DISABLE_EXCEPTIONS \
  const int float_excepts = fedisableexcept(FE_UNDERFLOW | FE_OVERFLOW);
// fedisableexcept() returns -1 on failure. Passing -1 to feenableexcept()
// would unmask every FP exception, so only restore a valid previous mask.
#define FLOATING_POINT_RESTORE_EXCEPTIONS                  \
  do {                                                     \
    if (float_excepts != -1) feenableexcept(float_excepts); \
  } while (0);
#else
#define FLOATING_POINT_DISABLE_EXCEPTIONS
#define FLOATING_POINT_RESTORE_EXCEPTIONS
#endif  // HAVE_FEXCEPT && CONFIG_DEBUG

// TFLite-based inference for the partition split-prune models (intra and inter).

extern "C" int av2_part_split_prune_tflite_exec(void **context,
                                                const float *ml_input,
                                                int input_len, float *ml_output,
                                                int output_len,
                                                MODEL_TYPE model_type) {
  assert(model_type != MODEL_OTHER);

  ensure_tflite_init(context, model_type);
  PartSplitContext *ctx = (PartSplitContext *)*context;
  tflite::Interpreter *interpreter;
  switch (model_type) {
    case MODEL_128X128: interpreter = ctx->model_128X128.get(); break;
    case MODEL_64X64: interpreter = ctx->model_64X64.get(); break;
    case MODEL_32X32: interpreter = ctx->model_32X32.get(); break;
    case MODEL_16X16: interpreter = ctx->model_16X16.get(); break;
    case MODEL_INTER_64X64: interpreter = ctx->model_inter_64x64.get(); break;
    case MODEL_INTER_32X32: interpreter = ctx->model_inter_32x32.get(); break;
    case MODEL_INTER_16X16: interpreter = ctx->model_inter_16x16.get(); break;
    case MODEL_INTER_8X8: interpreter = ctx->model_inter_8x8.get(); break;
    default: return -1;
  }
  tflite::ErrorReporter *reporter(tflite::DefaultErrorReporter());

  float *input = interpreter->typed_input_tensor<float>(0);
  for (int i = 0; i < input_len; i++) {
    input[i] = ml_input[i];
  }

  FLOATING_POINT_DISABLE_EXCEPTIONS
  auto status = interpreter->Invoke();
  FLOATING_POINT_RESTORE_EXCEPTIONS

  if (status != kTfLiteOk) {
    reporter->Report("Failed at invoke");
    exit(1);
  }

  float *output = interpreter->typed_output_tensor<float>(0);
  for (int i = 0; i < output_len; i++) {
    ml_output[i] = output[i];
  }
  return 0;
}

extern "C" void av2_part_split_prune_tflite_close(void **context) {
  PartSplitContext *ctx = (PartSplitContext *)*context;
  if (ctx != nullptr) delete ctx;
  *context = nullptr;
}
