#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <winenclave.h>

#include <algorithm> // For min
#include <cstddef>   // For size_t
#include <cstdint>
#include <cstring> // For memcpy
#include <exception>
#include <map>
#include <stdexcept> // For out_of_range
#include <string>
#include <string_view>
#include <tuple> // For apply
#include <type_traits>
#include <utility> // For move
#include <vector>

#include <jsc/vbs/interface.hpp>

using namespace jsc::detail;

// Compile-time choice of crypto backend. `impl` names the backend's
// cipher_impl instantiation and is used for every key and cipher operation in
// this file. Defining JSC_CRYPTO_OPENSSL selects the backend declared in
// cipher_openssl.hpp; otherwise the one in cipher_bcrypt.hpp is used
// (presumably Windows CNG / BCrypt* — inferred from the header name).
#ifdef JSC_CRYPTO_OPENSSL
#include <jsc/detail/cipher_openssl.hpp>
using impl = cipher_impl<openssl_cipher_impl>;
#else
#include <jsc/detail/cipher_bcrypt.hpp>
using impl = cipher_impl<bcrypt_cipher_impl>;
#endif

#ifdef JSC_VBS
// Enclave image configuration, exported under the name `__enclave_config`
// (presumably located and read by the Windows enclave loader). This is a
// positional aggregate initialiser, so the values below must stay in
// IMAGE_ENCLAVE_CONFIG's declared field order. Field names in the trailing
// comments follow the Windows SDK declaration — NOTE(review): re-check them
// against winnt.h if the SDK changes.
extern "C" __declspec(dllexport) const IMAGE_ENCLAVE_CONFIG __enclave_config{
    sizeof(IMAGE_ENCLAVE_CONFIG),      // Size
    IMAGE_ENCLAVE_MINIMUM_CONFIG_SIZE, // MinimumRequiredConfigSize
#ifdef NDEBUG
    0, // PolicyFlags: release builds are not debuggable
#else
    IMAGE_ENCLAVE_POLICY_DEBUGGABLE, // DO NOT SHIP DEBUGGABLE ENCLAVES TO
                                     // PRODUCTION
#endif
    0, // NumberOfImports
    0, // ImportList
    0, // ImportEntrySize
    {0xFE, 0xFE}, // family id (any remaining ID bytes are zero-initialised)
    {0x01, 0x01}, // image id (any remaining ID bytes are zero-initialised)
    0,            // version
    0,            // SVN
    0x10000000,   // size (256 MiB)
    1,            // number of threads
    IMAGE_ENCLAVE_FLAG_PRIMARY_IMAGE, // EnclaveFlags
};
#endif

/**
 * DLL entry point, declared with the documented
 * `BOOL WINAPI DllMain(HINSTANCE, DWORD, LPVOID)` signature.
 *
 * Performs no work for any notification and always reports success.
 *
 * @param hinst_dll    handle of this module (unused).
 * @param dw_reason    reason code for the call (unused).
 * @param lpv_reserved reserved/context pointer (unused).
 * @returns TRUE.
 */
extern "C" __declspec(dllexport) BOOL WINAPI
DllMain([[maybe_unused]] _In_ HINSTANCE hinst_dll,
        [[maybe_unused]] _In_ DWORD dw_reason,
        [[maybe_unused]] _In_ void *lpv_reserved) noexcept {
  return TRUE;
}

using vbs::object_id_t;

/**
 * Every live enclave-side object, keyed by the opaque id handed to the host.
 *
 * All tables draw their ids from the single `last_id` counter (see
 * insert_object), so barring counter wrap-around an id names at most one
 * object across all tables, and free_object can erase it from every table.
 *
 * NOTE(review): access is unsynchronised. This presumably relies on the
 * enclave being configured with a single thread (`1` in __enclave_config,
 * JSC_VBS builds only) — confirm before raising the thread count.
 */
struct Objects {
  /// Symmetric keys generated by encrypt_init; read by decrypt_init.
  std::map<object_id_t, impl::sym_key> symmetric_keys;

  /// RSA public keys from load_rsa_public_key; used by encrypt_init.
  std::map<object_id_t, impl::rsa_key> encrypting_keys;

  /// RSA private keys from load_rsa_private_key; used by rsa_sign.
  std::map<object_id_t, impl::rsa_key> signing_keys;

  /// Contexts created by encrypt_init.
  std::map<object_id_t, impl::encrypt_ctx> encrypt_ctxs;

  /// Contexts created by decrypt_init.
  std::map<object_id_t, impl::decrypt_ctx> decrypt_ctxs;

  /// Most recently issued id. Explicitly value-initialised so the counter
  /// (and therefore id 1 as the first handle) does not depend on the
  /// instance having static storage duration.
  object_id_t last_id{};
};

static Objects objects;

template <typename Routine, auto f,
          typename = std::enable_if_t<
              std::is_same_v<decltype(f), typename Routine::signature(*)>>>
static uintptr_t invoke_type_erased(void *param) noexcept {
  auto tuple = reinterpret_cast<routine_args_with_error_t<Routine> *>(param);
  uintptr_t result;
  std::string error;
  try {
    if constexpr (std::is_void_v<routine_result_t<Routine>>) {
      std::apply(f, tuple->args);
      result = 0;
    } else {
      result = static_cast<uintptr_t>(std::apply(f, tuple->args));
    }
  } catch (std::exception const &e) {
    result = static_cast<uintptr_t>(-1);
    error = e.what();
  } catch (...) {
    result = static_cast<uintptr_t>(-1);
    error = "unknown error";
  }

  if (!error.empty() && *tuple->error_len) {
    auto error_len_out = std::min(error.size(), *tuple->error_len);
    *tuple->error_len = error_len_out;
    std::memcpy(tuple->error, error.data(), error_len_out);
  }
  return result;
}

template <typename T>
static T &get_object(std::map<object_id_t, T> &map, object_id_t id) {
  try {
    return map.at(id);
  } catch (std::out_of_range const &) {
    throw std::runtime_error{"Invalid key"};
  }
}

/**
 * Moves `o` into `map` under a freshly issued id and returns that id.
 *
 * Ids come from the global `objects.last_id` counter, which is shared by
 * every table, so each call issues a new id.
 *
 * @throws std::runtime_error if the issued id is already present in `map`.
 *         The counter only ever increments, so this can happen only after it
 *         wraps around.
 */
template <typename T>
static object_id_t insert_object(std::map<object_id_t, T> &map, T &&o) {
  auto const id = ++objects.last_id;
  // emplace() leaves an existing entry in place. Returning `id` anyway would
  // hand the caller a handle to some *other* object, so treat it as an error.
  if (!map.emplace(id, std::move(o)).second)
    throw std::runtime_error{"Object id collision"};
  return id;
}

/**
 * Copies `src` into the caller-provided buffer `dest`.
 *
 * On entry `*dest_len` is the capacity of `dest` in bytes. On return it is
 * always `src.size()`. That is the number of bytes written on success, or the
 * capacity that would have been needed when `dest` is too small (in which
 * case nothing is written).
 *
 * @returns true if successful; false if insufficient size.
 */
[[nodiscard]]
static bool copy_out(std::vector<uint8_t> const &src, uint8_t *dest,
                     size_t *dest_len) {
  if (src.size() > *dest_len) {
    *dest_len = src.size();
    return false;
  }
  // Empty output may come with a null `dest` (zero-capacity buffer), and an
  // empty vector's data() may be null; memcpy with a null pointer is
  // undefined even for a zero length, so skip the call entirely.
  if (!src.empty())
    std::memcpy(dest, src.data(), src.size());
  *dest_len = src.size();
  return true;
}

// Defines the exported enclave routine NAME. The generated function forwards
// its `param` pointer to invoke_type_erased, using the routine descriptor
// ::jsc::detail::vbs::NAME and the implementation function NAME##_impl.
// NAME##_impl must have exactly the descriptor's `signature` type (enforced by
// invoke_type_erased's enable_if). The static_assert checks that the exported
// symbol name equals the descriptor's `name`, so the two cannot drift apart.
#define JSC_EXPORT(NAME)                                                       \
  extern "C" __declspec(dllexport) uintptr_t CALLBACK NAME(                    \
      _In_ void *param) noexcept {                                             \
    static_assert(std::string_view{#NAME} == ::jsc::detail::vbs::NAME::name);  \
    return invoke_type_erased<::jsc::detail::vbs::NAME, NAME##_impl>(param);   \
  }

// Releases the object named by `id`. Ids are issued from one counter shared
// by every table, so `id` is simply erased from each table in turn; an id
// that is not present anywhere is ignored.
// NOTE(review): freeing a symmetric key does not touch contexts that were
// created from it. Whether such contexts remain usable depends on `impl` —
// confirm.
static void free_object_impl(object_id_t id) {
  auto const drop_from = [id](auto &table) { table.erase(id); };
  drop_from(objects.symmetric_keys);
  drop_from(objects.encrypting_keys);
  drop_from(objects.signing_keys);
  drop_from(objects.encrypt_ctxs);
  drop_from(objects.decrypt_ctxs);
}

JSC_EXPORT(free_object)

// Parses the `len`-byte RSA public key blob at `blob` and stores it in the
// encrypting-key table (the keys encrypt_init wraps symmetric keys with).
// Returns the new key's object id.
static object_id_t load_rsa_public_key_impl(char const *blob, size_t len) {
  return insert_object(objects.encrypting_keys,
                       impl::parse_rsa_public_key({blob, len}));
}

JSC_EXPORT(load_rsa_public_key)

// Parses the `len`-byte RSA private key blob at `blob` and stores it in the
// signing-key table (the keys rsa_sign uses). Returns the new key's object id.
static object_id_t load_rsa_private_key_impl(char const *blob, size_t len) {
  return insert_object(objects.signing_keys,
                       impl::parse_rsa_private_key({blob, len}));
}

JSC_EXPORT(load_rsa_private_key)

/**
 * Starts a new encryption under a freshly generated symmetric key.
 *
 * The symmetric key is encrypted with the RSA public key `rsa_key_id` and the
 * result written to `enc_key` (`*enc_key_len` is capacity in / size out).
 * The IV from `impl::encrypt_init` is written to `iv`. NOTE(review): `iv`
 * must hold the backend's full IV size; no length is checked here.
 * The symmetric key's object id is stored in `*sym_key_id`.
 *
 * @returns the new encryption context's object id, or
 *          `static_cast<object_id_t>(-1)` without creating any objects when
 *          `enc_key` is too small (`*enc_key_len` then holds the needed size).
 */
static object_id_t encrypt_init_impl(object_id_t rsa_key_id, uint8_t *enc_key,
                                     size_t *enc_key_len, uint8_t *iv,
                                     object_id_t *sym_key_id) {
  auto sym_key = impl::generate_sym_key();

  // Wrap the new key for the host; the wrapped bytes live only for this call.
  bool const wrapped = [&] {
    auto &wrapping_key = get_object(objects.encrypting_keys, rsa_key_id);
    return copy_out(
        impl::rsa_encrypt(wrapping_key, sym_key.data(), sym_key.size()),
        enc_key, enc_key_len);
  }();

  if (!wrapped) {
    return static_cast<object_id_t>(-1);
  }

  auto [encrypt_ctx, iv_result] = impl::encrypt_init(sym_key);
  std::memcpy(iv, iv_result.data(), iv_result.size());

  // Key first, then context: callers observe the key id before the ctx id.
  *sym_key_id = insert_object(objects.symmetric_keys, std::move(sym_key));
  return insert_object(objects.encrypt_ctxs, std::move(encrypt_ctx));
}

JSC_EXPORT(encrypt_init)

/**
 * Runs `impl::encrypt_update` on the encryption context `encryptor_id` with
 * `plaintext_len` bytes of `plaintext`, and copies the output to `ciphertext`.
 *
 * On entry `*ciphertext_len` is the capacity of `ciphertext`; on return it is
 * the number of bytes this step produced.
 *
 * @throws std::runtime_error ("Invalid key") if `encryptor_id` is unknown.
 * @throws std::runtime_error ("Output buffer too small") if `ciphertext`
 *         cannot hold the output. `*ciphertext_len` then holds the size that
 *         was needed.
 * Anything thrown by `impl::encrypt_update` propagates.
 */
static void encrypt_update_impl(object_id_t encryptor_id,
                                uint8_t const *plaintext, size_t plaintext_len,
                                uint8_t *ciphertext, size_t *ciphertext_len) {
  auto &encryptor = get_object(objects.encrypt_ctxs, encryptor_id);

  auto result = impl::encrypt_update(encryptor, plaintext, plaintext_len);

  // impl::encrypt_update has already run on `encryptor` (presumably advancing
  // its state), so a retry would not reproduce `result`. Returning normally
  // here used to report success while discarding the ciphertext; surface it
  // as an error instead.
  if (!copy_out(result, ciphertext, ciphertext_len))
    throw std::runtime_error{"Output buffer too small"};
}

JSC_EXPORT(encrypt_update)

/**
 * Finishes the encryption context `encryptor_id` via `impl::encrypt_finish`.
 * The final output is copied to `ciphertext` and the returned tag to `tag`.
 *
 * On entry `*ciphertext_len` is the capacity of `ciphertext`; on return it is
 * the number of final bytes. NOTE(review): `tag` must hold the backend's full
 * tag size; no length is checked here.
 *
 * @throws std::runtime_error ("Invalid key") if `encryptor_id` is unknown.
 * @throws std::runtime_error ("Output buffer too small") if `ciphertext`
 *         cannot hold the output. `*ciphertext_len` then holds the size that
 *         was needed and `tag` is not written.
 * Anything thrown by `impl::encrypt_finish` propagates.
 */
static void encrypt_finish_impl(object_id_t encryptor_id, uint8_t *ciphertext,
                                size_t *ciphertext_len, uint8_t *tag) {
  auto &encryptor = get_object(objects.encrypt_ctxs, encryptor_id);

  auto [result, tag_result] = impl::encrypt_finish(encryptor);

  // The context has already been finished, so the output and tag cannot be
  // obtained again. Returning normally here used to report success with
  // neither written; surface it as an error instead.
  if (!copy_out(result, ciphertext, ciphertext_len))
    throw std::runtime_error{"Output buffer too small"};

  std::memcpy(tag, tag_result.data(), tag_result.size());
}

JSC_EXPORT(encrypt_finish)

// Creates a decryption context from the symmetric key `key_id` (symmetric
// keys are only created by encrypt_init) plus the caller's `iv` and `tag`
// buffers. Returns the new context's object id.
// Throws std::runtime_error ("Invalid key") if `key_id` is unknown.
static object_id_t decrypt_init_impl(object_id_t key_id, uint8_t const *iv,
                                     uint8_t const *tag) {
  return insert_object(
      objects.decrypt_ctxs,
      impl::decrypt_init(get_object(objects.symmetric_keys, key_id), iv, tag));
}

JSC_EXPORT(decrypt_init)

/**
 * Runs `impl::decrypt_update` on the decryption context `decryptor_id` with
 * `ciphertext_len` bytes of `ciphertext`, and copies the output to
 * `plaintext`.
 *
 * On entry `*plaintext_len` is the capacity of `plaintext`; on return it is
 * the number of bytes this step produced.
 *
 * @throws std::runtime_error ("Invalid key") if `decryptor_id` is unknown.
 * @throws std::runtime_error ("Output buffer too small") if `plaintext`
 *         cannot hold the output. `*plaintext_len` then holds the size that
 *         was needed.
 * Anything thrown by `impl::decrypt_update` propagates.
 */
static void decrypt_update_impl(object_id_t decryptor_id,
                                uint8_t const *ciphertext,
                                size_t ciphertext_len, uint8_t *plaintext,
                                size_t *plaintext_len) {
  auto &decryptor = get_object(objects.decrypt_ctxs, decryptor_id);

  auto result = impl::decrypt_update(decryptor, ciphertext, ciphertext_len);

  // impl::decrypt_update has already run on `decryptor` (presumably advancing
  // its state), so a retry would not reproduce `result`. Returning normally
  // here used to report success while discarding the plaintext; surface it
  // as an error instead.
  if (!copy_out(result, plaintext, plaintext_len))
    throw std::runtime_error{"Output buffer too small"};
}

JSC_EXPORT(decrypt_update)

/**
 * Finishes the decryption context `decryptor_id` via `impl::decrypt_finish`
 * and copies the final output to `plaintext`.
 *
 * On entry `*plaintext_len` is the capacity of `plaintext`; on return it is
 * the number of final bytes.
 *
 * @throws std::runtime_error ("Invalid key") if `decryptor_id` is unknown.
 * @throws std::runtime_error ("Output buffer too small") if `plaintext`
 *         cannot hold the output. `*plaintext_len` then holds the size that
 *         was needed.
 * Anything thrown by `impl::decrypt_finish` propagates.
 */
static void decrypt_finish_impl(object_id_t decryptor_id, uint8_t *plaintext,
                                size_t *plaintext_len) {
  auto &decryptor = get_object(objects.decrypt_ctxs, decryptor_id);

  auto result = impl::decrypt_finish(decryptor);

  // The context has already been finished, so `result` cannot be obtained
  // again. Returning normally here used to report success while discarding
  // it; surface it as an error instead.
  if (!copy_out(result, plaintext, plaintext_len))
    throw std::runtime_error{"Output buffer too small"};
}

JSC_EXPORT(decrypt_finish)

/**
 * Signs `len` bytes of `data` with the RSA private key `key_id` and copies
 * the signature to `out`.
 *
 * On entry `*out_len` is the capacity of `out`; on return it is the signature
 * size. An undersized `out` is not treated as an error: nothing is written,
 * the call returns normally, and `*out_len` reports the size needed. Unlike
 * the cipher routines, no context is involved, so presumably the caller can
 * simply retry with a larger buffer.
 *
 * @throws std::runtime_error ("Invalid key") if `key_id` is unknown.
 */
static void rsa_sign_impl(object_id_t key_id, uint8_t const *data, size_t len,
                          uint8_t *out, size_t *out_len) {
  auto &private_key = get_object(objects.signing_keys, key_id);
  auto const sig = impl::rsa_sign(private_key, data, len);
  static_cast<void>(copy_out(sig, out, out_len));
}

JSC_EXPORT(rsa_sign)
