#include <jsc/detail/cipher_bcrypt.hpp>

#include <cassert>
#include <cstddef> // For size_t
#include <cstdint>
#include <cstring>   // For memcpy
#include <stdexcept> // For runtime_error
#include <string>
#include <string_view>
#include <utility> // For move
#include <vector>

#include <bcrypt.h>

#include <jsc/detail/encoding.hpp>

namespace jsc {
namespace detail {

// BCrypt reports lengths through ULONG* (unsigned long*) out-parameters, and an
// unsigned int* does not convert to unsigned long*. Inside this namespace the
// name uint32_t is therefore re-pointed at unsigned long so that local length
// variables can be passed to BCrypt directly.
typedef unsigned long uint32_t;

// Builds the exception thrown when a BCrypt call fails.
//
// @param origin Name of the failing API call; used as the message prefix.
// @param status Status value returned by that call.
// @return A std::runtime_error whose message is "<origin>: NTStatus 0xXXXXXXXX".
//
// NTSTATUS codes are documented and searched for in hexadecimal
// (e.g. 0xC000000D). Printing the signed decimal value instead turns every
// error code into an opaque negative number, so format it as 8 hex digits.
static std::runtime_error nt_error(char const *origin, int status) {
  static constexpr char kHexDigits[] = "0123456789ABCDEF";
  auto const code = static_cast<std::uint32_t>(status);

  std::string message{origin};
  message += ": NTStatus 0x";
  for (int shift = 28; shift >= 0; shift -= 4)
    message += kHexDigits[(code >> shift) & 0xFu];
  return std::runtime_error{message};
}

// Shorthand for the BCrypt specialization of the cipher interface whose
// members are defined in this file.
typedef cipher_impl<bcrypt_cipher_impl> impl;

// Opens the BCrypt algorithm provider identified by `provider` (an algorithm
// id such as BCRYPT_AES_ALGORITHM), using the default implementation and no
// flags, and wraps the resulting handle in a BCRYPT_ALG_ptr.
// Throws std::runtime_error if the call fails or yields a null handle.
static BCRYPT_ALG_ptr open_algorithm(wchar_t const *provider) {
  void *handle = nullptr;
  auto const rc = BCryptOpenAlgorithmProvider(
      &handle, provider, /*implementation=*/nullptr, /*flags=*/0);
  if (rc == 0 && handle != nullptr)
    return BCRYPT_ALG_ptr{handle};
  throw nt_error("BCryptOpenAlgorithmProvider", rc);
}

// Imports an RSA key from a serialized BCrypt key blob.
//
// @param type Blob type passed to BCryptImportKeyPair (this file uses
//             BCRYPT_RSAPUBLIC_BLOB and BCRYPT_RSAPRIVATE_BLOB).
// @param blob Pointer to the blob bytes.
// @param len  Blob size in bytes.
// @return The RSA provider handle together with the imported key handle
//         (presumably so the provider stays open as long as the key).
// @throws std::runtime_error if the import fails or yields a null handle.
static BCryptKeyPair key_pair_from_blob(wchar_t const *type,
                                        uint8_t const *blob, size_t len) {
  auto rsa_provider = open_algorithm(BCRYPT_RSA_ALGORITHM);

  void *handle = nullptr;
  auto const rc = BCryptImportKeyPair(
      rsa_provider.get(), /*import_key=*/nullptr, type, &handle,
      const_cast<uint8_t *>(blob), static_cast<uint32_t>(len), /*flags=*/0);
  if (rc != 0 || handle == nullptr)
    throw nt_error("BCryptImportKeyPair", rc);

  return {std::move(rsa_provider), BCRYPT_KEY_ptr{handle}};
}

// Parses an RSA public key serialized as a BCRYPT_RSAPUBLIC_BLOB.
template <>
impl::rsa_key impl::parse_rsa_public_key(std::string_view blob) {
  auto const *bytes = reinterpret_cast<uint8_t const *>(blob.data());
  return key_pair_from_blob(BCRYPT_RSAPUBLIC_BLOB, bytes, blob.size());
}

// Parses an RSA private key serialized as a BCRYPT_RSAPRIVATE_BLOB.
template <>
impl::rsa_key impl::parse_rsa_private_key(std::string_view blob) {
  auto const *bytes = reinterpret_cast<uint8_t const *>(blob.data());
  return key_pair_from_blob(BCRYPT_RSAPRIVATE_BLOB, bytes, blob.size());
}

// Policy type selecting BCryptEncrypt for the shared RSA and AES helpers;
// `name` is the prefix used in error messages.
struct EncryptFunction {
  static constexpr auto function = &BCryptEncrypt;
  static constexpr char const *name = "BCryptEncrypt";
};

// Policy type selecting BCryptDecrypt for the shared RSA and AES helpers;
// `name` is the prefix used in error messages.
struct DecryptFunction {
  static constexpr auto function = &BCryptDecrypt;
  static constexpr char const *name = "BCryptDecrypt";
};

// Common RSA-OAEP (SHA-256, empty label) implementation shared by rsa_encrypt
// and rsa_decrypt.
//
// @tparam Function EncryptFunction or DecryptFunction: supplies the BCrypt
//                  entry point and the name used in error messages.
// @param key  RSA key pair to operate with.
// @param data Input bytes (plaintext when encrypting, ciphertext when
//             decrypting).
// @param len  Number of input bytes.
// @return Exactly the bytes BCrypt reports having written.
// @throws std::runtime_error if either BCrypt call fails or reports more output
//         than the buffer can hold.
template <typename Function>
std::vector<uint8_t> rsa_encrypt_decrypt(impl::rsa_key &key,
                                         uint8_t const *data, size_t len) {
  BCRYPT_OAEP_PADDING_INFO padding_info{
      .pszAlgId = BCRYPT_SHA256_ALGORITHM, .pbLabel = nullptr, .cbLabel = 0};

  // Determines needed length (a null output buffer makes this a size query)
  uint32_t result_len{};
  auto status =
      Function::function(key.key.get(), const_cast<uint8_t *>(data),
                         static_cast<uint32_t>(len), &padding_info,
                         /*iv=*/nullptr, /*iv_len=*/0, /*output=*/nullptr,
                         /*output_len=*/0, &result_len, BCRYPT_PAD_OAEP);
  if (status)
    throw nt_error(Function::name, status);

  std::vector<uint8_t> result(result_len);
  status = Function::function(key.key.get(), const_cast<uint8_t *>(data),
                              static_cast<uint32_t>(len), &padding_info,
                              /*iv=*/nullptr, /*iv_len=*/0, result.data(),
                              result_len, &result_len, BCRYPT_PAD_OAEP);
  if (status)
    throw nt_error(Function::name, status);
  if (result_len > result.size())
    throw std::runtime_error{std::string{Function::name} +
                             ": unexpected length"};

  // result_len now holds the number of bytes actually written. This can be
  // smaller than the queried buffer size: an OAEP plaintext is always shorter
  // than the RSA modulus. Trim so decryption does not return trailing zero
  // bytes that were never written.
  result.resize(result_len);
  return result;
}

// Encrypts `len` bytes at `data` with RSA-OAEP (SHA-256) under `key`.
template <>
std::vector<uint8_t> impl::rsa_encrypt(impl::rsa_key &key, uint8_t const *data,
                                       size_t len) {
  using Op = EncryptFunction;
  return rsa_encrypt_decrypt<Op>(key, data, len);
}

// Decrypts `len` bytes of RSA-OAEP (SHA-256) ciphertext at `data` with `key`.
template <>
std::vector<uint8_t> impl::rsa_decrypt(impl::rsa_key &key, uint8_t const *data,
                                       size_t len) {
  using Op = DecryptFunction;
  return rsa_encrypt_decrypt<Op>(key, data, len);
}

// Signs `len` bytes at `data` with RSA-PSS over SHA-256.
//
// The message is first hashed with BCryptHash (SHA-256, no HMAC secret). The
// digest is then signed with PSS padding and a 32-byte salt. The signature
// buffer holds kSignatureSize bytes and BCrypt must fill it exactly.
// @throws std::runtime_error if a BCrypt call fails or the signature length
//         differs from kSignatureSize.
template <>
std::vector<uint8_t> impl::rsa_sign(impl::rsa_key &key, uint8_t const *data,
                                    size_t len) {
  // The SHA-256 provider is only needed for hashing, so it lives (and is
  // released) inside this lambda, before signing starts.
  auto digest = [&] {
    std::vector<uint8_t> hash(kSignatureDigestSize);
    auto sha256 = open_algorithm(BCRYPT_SHA256_ALGORITHM);
    auto const rc =
        BCryptHash(sha256.get(), /*secret=*/nullptr, /*secret_len=*/0,
                   const_cast<uint8_t *>(data), static_cast<uint32_t>(len),
                   hash.data(), kSignatureDigestSize);
    if (rc)
      throw nt_error("BCryptHash", rc);
    return hash;
  }();

  std::vector<uint8_t> signature(kSignatureSize);
  BCRYPT_PSS_PADDING_INFO pss{.pszAlgId = BCRYPT_SHA256_ALGORITHM,
                              .cbSalt = 32};
  uint32_t written{};

  auto const rc = BCryptSignHash(key.key.get(), &pss, digest.data(),
                                 kSignatureDigestSize, signature.data(),
                                 kSignatureSize, &written, BCRYPT_PAD_PSS);
  if (rc)
    throw nt_error("BCryptSignHash", rc);
  if (written != kSignatureSize)
    throw std::runtime_error{"BCryptSignHash: unexpected length"};

  return signature;
}

// Returns `len` bytes drawn from the system-preferred BCrypt RNG.
// Throws std::runtime_error if BCryptGenRandom fails.
static std::vector<uint8_t> random_vector(std::size_t len) {
  std::vector<uint8_t> bytes(len);
  if (auto const rc = BCryptGenRandom(/*algorithm=*/nullptr, bytes.data(),
                                      static_cast<uint32_t>(len),
                                      BCRYPT_USE_SYSTEM_PREFERRED_RNG))
    throw nt_error("BCryptGenRandom", rc);
  return bytes;
}

// Generates a fresh random symmetric (AES) key of kAesKeySize bytes.
template <>
impl::sym_key impl::generate_sym_key() {
  auto const key_size = kAesKeySize;
  return random_vector(key_size);
}

static std::tuple<BCRYPT_ALG_ptr, BCRYPT_KEY_ptr,
                  BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO>
key_from_secret(std::vector<uint8_t> const &secret) {
  auto alg = open_algorithm(BCRYPT_AES_ALGORITHM);
  auto status =
      BCryptSetProperty(alg.get(), BCRYPT_CHAINING_MODE,
                        const_cast<uint8_t *>(reinterpret_cast<uint8_t const *>(
                            BCRYPT_CHAIN_MODE_GCM)),
                        sizeof(BCRYPT_CHAIN_MODE_GCM), 0);
  if (status)
    throw nt_error("BCryptSetProperty(BCRYPT_CHAINING_MODE)", status);

  void *key_raw;
  status = BCryptGenerateSymmetricKey(
      alg.get(), &key_raw, /*key_object=*/nullptr, /*key_object_len=*/0,
      const_cast<uint8_t *>(secret.data()),
      static_cast<uint32_t>(secret.size()), /*flags=*/0);
  if (status || !key_raw)
    throw nt_error("BCryptGenerateSymmetricKey", status);
  BCRYPT_KEY_ptr key{key_raw};

  BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO cipher_info;
  BCRYPT_INIT_AUTH_MODE_INFO(cipher_info);
  cipher_info.dwFlags = BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG;
  cipher_info.cbNonce = kAesIvSize;
  cipher_info.cbTag = kAesTagSize;
  cipher_info.cbMacContext = kAesTagSize;

  return {std::move(alg), std::move(key), std::move(cipher_info)};
}

// Starts a streaming AES-GCM encryption keyed by `secret`.
//
// A random kAesIvSize-byte IV is generated. The returned context owns the IV,
// tag and MAC-context buffers that its cipher_info points into. A copy of the
// IV is returned alongside so the caller can store or transmit it.
template <>
std::tuple<impl::encrypt_ctx, std::vector<uint8_t>>
impl::encrypt_init(impl::sym_key &secret) {
  auto [alg, key, cipher_info] = key_from_secret(secret);

  auto nonce = random_vector(kAesIvSize);
  std::vector<uint8_t> nonce_copy = nonce;
  std::vector<uint8_t> tag(kAesTagSize);
  std::vector<uint8_t> mac_context(kAesTagSize);

  // Moving a std::vector keeps its heap buffer, so these pointers stay valid
  // once the vectors are moved into the context below.
  cipher_info.pbNonce = nonce.data();
  cipher_info.pbTag = tag.data();
  cipher_info.pbMacContext = mac_context.data();

  return {BCryptEncryptContext{
              std::move(alg), std::move(key), std::move(cipher_info),
              std::move(nonce), std::move(tag),
              /*partial_block=*/std::vector<uint8_t>(kAesBlockSize),
              /*partial_block_len=*/0,
              /*block_state=*/std::vector<uint8_t>(kAesBlockSize),
              std::move(mac_context)},
          std::move(nonce_copy)};
}

// Common implementation of encrypt_update and decrypt_update.
//
// With BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG set, every non-final
// BCryptEncrypt/BCryptDecrypt call must process whole AES blocks. Input is
// therefore staged in ctx.partial_block until a full block is available. Any
// tail shorter than kAesBlockSize stays there for the next update, or for
// encrypt_decrypt_finish.
//
// @param ctx       Context from encrypt_init/decrypt_init; its partial block,
//                  block state and cipher_info are updated in place.
// @param plaintext Input bytes (ciphertext when decrypting); may be null only
//                  when len is 0.
// @param len       Number of input bytes.
// @return Output for every complete block processed by this call (a multiple
//         of kAesBlockSize, possibly empty).
// @throws std::runtime_error if a BCrypt call fails or reports an unexpected
//         output length.
template <typename Function>
static std::vector<uint8_t> encrypt_decrypt_update(impl::encrypt_ctx &ctx,
                                                   uint8_t const *plaintext,
                                                   size_t len) {
  // Fill partial block first
  if (ctx.partial_block_len + len < kAesBlockSize) {
    // Still smaller than a block. Skip the copy for empty input: plaintext may
    // be null then, and memcpy from a null pointer is undefined even for zero
    // bytes.
    if (len) {
      std::memcpy(&ctx.partial_block[ctx.partial_block_len], plaintext, len);
      ctx.partial_block_len += len;
    }
    return {};
  }

  // The output is the whole blocks contained in partial_block_len + len bytes.
  // partial_block_len is always below kAesBlockSize, so len + kAesBlockSize - 1
  // bytes always suffice.
  std::vector<uint8_t> ciphertext(len + kAesBlockSize - 1);
  auto ciphertext_out = ciphertext.data();

  if (ctx.partial_block_len) {
    // Complete the buffered block; the check above guarantees len covers it.
    auto copied_len = kAesBlockSize - ctx.partial_block_len;
    std::memcpy(&ctx.partial_block[ctx.partial_block_len], plaintext,
                copied_len);
    plaintext += copied_len;
    len -= copied_len;

    // Encrypt one block
    uint32_t ciphertext_len{};
    auto status = Function::function(
        ctx.key.get(), ctx.partial_block.data(), kAesBlockSize,
        &ctx.cipher_info, ctx.block_state.data(), kAesBlockSize, ciphertext_out,
        kAesBlockSize, &ciphertext_len,
        /*flags=*/0);
    if (status)
      throw nt_error(Function::name, status);
    if (ciphertext_len != kAesBlockSize)
      throw std::runtime_error{std::string{Function::name} +
                               ": unexpected length"};
    ciphertext_out += ciphertext_len;
    ctx.partial_block_len = 0;
  }

  // Encrypt full blocks
  auto blocks = len / kAesBlockSize;
  if (blocks) {
    uint32_t ciphertext_len{};
    auto blocks_len = blocks * kAesBlockSize;
    auto status = Function::function(
        ctx.key.get(), const_cast<uint8_t *>(plaintext),
        static_cast<uint32_t>(blocks_len), &ctx.cipher_info,
        ctx.block_state.data(), kAesBlockSize, ciphertext_out,
        static_cast<uint32_t>(blocks_len), &ciphertext_len,
        /*flags=*/0);
    if (status)
      throw nt_error(Function::name, status);
    if (ciphertext_len != blocks_len)
      throw std::runtime_error{std::string{Function::name} +
                               ": unexpected length"};
    ciphertext_out += blocks_len;
    plaintext += blocks_len;
    len -= blocks_len;
  }

  // The output can fill the buffer exactly. For example, with AES's 16-byte
  // block, 15 buffered bytes plus 1 new byte produce one block in a 16-byte
  // buffer. The bound is therefore inclusive.
  assert(ciphertext_out <= ciphertext.data() + ciphertext.size());
  ciphertext.resize(static_cast<size_t>(ciphertext_out - ciphertext.data()));

  // Save remaining
  assert(len < kAesBlockSize);
  if (len) {
    std::memcpy(ctx.partial_block.data(), plaintext, len);
    ctx.partial_block_len = static_cast<uint8_t>(len);
  }

  return ciphertext;
}

// Common final step of encrypt_finish and decrypt_finish.
//
// Clears the chaining flag so BCrypt treats this as the last call of the GCM
// operation, then processes whatever is still buffered in ctx.partial_block.
// When encrypting, this final call writes the tag into the buffer that
// ctx.cipher_info.pbTag points to. When decrypting, BCrypt checks the tag here
// and reports a mismatch as an error status.
// @return Output for the buffered bytes (same length as partial_block_len).
// @throws std::runtime_error if the call fails or the length is unexpected.
template <typename Function>
static std::vector<uint8_t> encrypt_decrypt_finish(impl::encrypt_ctx &ctx) {
  ctx.cipher_info.dwFlags &= ~BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG;

  auto const remaining = ctx.partial_block_len;
  std::vector<uint8_t> output(remaining);
  // Even with nothing left to process, hand BCrypt a non-null output pointer
  // so the final call still runs and the tag is produced/checked.
  output.reserve(1);

  uint32_t written{};
  auto const rc = Function::function(
      ctx.key.get(), ctx.partial_block.data(), remaining, &ctx.cipher_info,
      ctx.block_state.data(), kAesBlockSize, output.data(), remaining,
      &written, /*flags=*/0);
  if (rc)
    throw nt_error(Function::name, rc);
  if (written != remaining)
    throw std::runtime_error{std::string{Function::name} +
                             ": unexpected length"};

  return output;
}

// Feeds `len` plaintext bytes into a streaming AES-GCM encryption and returns
// the ciphertext for every block that became complete (possibly empty).
template <>
std::vector<uint8_t> impl::encrypt_update(impl::encrypt_ctx &ctx,
                                          uint8_t const *plaintext,
                                          size_t len) {
  using Op = EncryptFunction;
  return encrypt_decrypt_update<Op>(ctx, plaintext, len);
}

// Completes a streaming AES-GCM encryption. Returns the ciphertext for the
// still-buffered bytes together with the authentication tag; the tag vector is
// moved out of `ctx`.
template <>
std::tuple<std::vector<uint8_t>, std::vector<uint8_t>>
impl::encrypt_finish(impl::encrypt_ctx &ctx) {
  // The final call must run first: it is what writes the tag into ctx.tag.
  std::vector<uint8_t> tail = encrypt_decrypt_finish<EncryptFunction>(ctx);
  return {std::move(tail), std::move(ctx.tag)};
}

// Starts a streaming AES-GCM decryption keyed by `secret`.
//
// @param iv  Pointer to kAesIvSize bytes of IV (copied).
// @param tag Pointer to kAesTagSize bytes of expected tag (copied); it is
//            checked when decryption is finished.
// The returned context owns the buffers that its cipher_info points into.
template <>
impl::decrypt_ctx impl::decrypt_init(impl::sym_key &secret, uint8_t const *iv,
                                     uint8_t const *tag) {
  auto [alg, key, cipher_info] = key_from_secret(secret);

  std::vector<uint8_t> nonce(iv, iv + kAesIvSize);
  std::vector<uint8_t> expected_tag(tag, tag + kAesTagSize);
  std::vector<uint8_t> mac_context(kAesTagSize);

  // Moving a std::vector keeps its heap buffer, so these pointers remain valid
  // after the vectors are moved into the context.
  cipher_info.pbNonce = nonce.data();
  cipher_info.pbTag = expected_tag.data();
  cipher_info.pbMacContext = mac_context.data();

  return BCryptEncryptContext{
      std::move(alg),
      std::move(key),
      std::move(cipher_info),
      std::move(nonce),
      std::move(expected_tag),
      /*partial_block=*/std::vector<uint8_t>(kAesBlockSize),
      /*partial_block_len=*/0,
      /*block_state=*/std::vector<uint8_t>(kAesBlockSize),
      std::move(mac_context)};
}

// Feeds `len` ciphertext bytes into a streaming AES-GCM decryption and returns
// the plaintext for every block that became complete (possibly empty).
// NOTE: this plaintext is not yet authenticated. The tag is only checked by the
// final, non-chained call made from decrypt_finish, so callers should not act
// on it until decrypt_finish succeeds.
template <>
std::vector<uint8_t> impl::decrypt_update(impl::decrypt_ctx &ctx,
                                          uint8_t const *ciphertext,
                                          size_t len) {
  using Op = DecryptFunction;
  return encrypt_decrypt_update<Op>(ctx, ciphertext, len);
}

// Completes a streaming AES-GCM decryption and returns the plaintext for the
// still-buffered bytes. The final BCryptDecrypt call is where the tag given to
// decrypt_init is verified; any failure surfaces as std::runtime_error.
template <>
std::vector<uint8_t> impl::decrypt_finish(impl::decrypt_ctx &ctx) {
  using Op = DecryptFunction;
  return encrypt_decrypt_finish<Op>(ctx);
}

} // namespace detail
} // namespace jsc
