#include <jsc/detail/cipher_openssl.hpp>

#include <cassert>
#include <cstddef> // For size_t
#include <cstdint>
#include <stdexcept> // For runtime_error
#include <string_view>
#include <utility> // For move
#include <vector>

#include <openssl/bio.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/rand.h>
#include <openssl/rsa.h>

#include <jsc/detail/encoding.hpp>
#include <jsc/detail/openssl_utils.hpp>

namespace jsc {
namespace detail {

// Shorthand for the OpenSSL-backed cipher_impl. Every definition below is an
// explicit specialization of one of its members.
using impl = cipher_impl<openssl_cipher_impl>;

template <>
impl::rsa_key impl::parse_rsa_public_key(std::string_view pem) {
  // PEM public-key parsing is shared with the host side and lives as a free
  // function in this namespace; the OpenSSL backend just forwards to it.
  impl::rsa_key parsed = detail::parse_rsa_public_key(pem);
  return parsed;
}

template <>
impl::rsa_key impl::parse_rsa_private_key(std::string_view pem) {
  // Decodes a PEM-encoded private key and checks that it is an RSA key.
  //
  // Throws std::runtime_error if the PEM is too large for OpenSSL's `int`
  // length parameter or the key is not RSA, and openssl_error if OpenSSL
  // cannot allocate the BIO or decode the key (including encrypted keys,
  // since no passphrase is available here).

  // BIO_new_mem_buf takes an `int` length, and -1 means "use strlen".
  // Reject sizes that do not survive the narrowing, so we never parse a
  // truncated key or run strlen over a non-terminated buffer.
  auto const pem_len = static_cast<int>(pem.size());
  if (pem_len < 0 || static_cast<size_t>(pem_len) != pem.size())
    throw std::runtime_error{"PEM private key is too large"};

  BIO_ptr source{BIO_new_mem_buf(pem.data(), pem_len)};
  if (!source)
    throw openssl_error("BIO_new_mem_buf");

  // With both callback and userdata null, OpenSSL uses its default
  // passphrase callback. For an encrypted key, that prompts on the
  // controlling terminal. Library code must not block on interactive input,
  // so report "no passphrase available" instead; decoding then fails and is
  // reported via openssl_error below.
  auto const no_passphrase = [](char *, int, int, void *) -> int {
    return -1;
  };

  EVP_PKEY_ptr pkey{
      PEM_read_bio_PrivateKey(&*source, nullptr, no_passphrase, nullptr)};
  if (!pkey)
    throw openssl_error("PEM_read_bio_PrivateKey");

  if (!EVP_PKEY_is_a(&*pkey, "RSA"))
    throw std::runtime_error{"Key is not RSA"};

  return pkey;
}

template <>
std::vector<uint8_t> impl::rsa_encrypt(impl::rsa_key &key, uint8_t const *data,
                                       size_t len) {
  // RSA-OAEP encryption of data[0..len) under `key`, with SHA-256 as both the
  // OAEP digest and the MGF1 digest. Throws openssl_error on any OpenSSL
  // failure.
  EVP_PKEY_CTX_ptr ctx{EVP_PKEY_CTX_new(&*key, nullptr)};
  if (!ctx)
    throw openssl_error("EVP_PKEY_CTX_new");

  auto *const raw_ctx = &*ctx;
  if (EVP_PKEY_encrypt_init(raw_ctx) != 1)
    throw openssl_error("EVP_PKEY_encrypt_init");
  if (EVP_PKEY_CTX_set_rsa_padding(raw_ctx, RSA_PKCS1_OAEP_PADDING) != 1)
    throw openssl_error("EVP_PKEY_CTX_set_rsa_padding");

  auto const digest = EVP_sha256();
  if (EVP_PKEY_CTX_set_rsa_mgf1_md(raw_ctx, digest) != 1)
    throw openssl_error("EVP_PKEY_CTX_set_rsa_mgf1_md");
  if (EVP_PKEY_CTX_set_rsa_oaep_md(raw_ctx, digest) != 1)
    throw openssl_error("EVP_PKEY_CTX_set_rsa_oaep_md");

  // With a null output buffer, OpenSSL only reports the size it needs.
  size_t out_len = 0;
  if (EVP_PKEY_encrypt(raw_ctx, nullptr, &out_len, data, len) != 1)
    throw openssl_error("EVP_PKEY_encrypt");

  std::vector<uint8_t> ciphertext(out_len);
  if (EVP_PKEY_encrypt(raw_ctx, ciphertext.data(), &out_len, data, len) != 1)
    throw openssl_error("EVP_PKEY_encrypt");
  if (ciphertext.size() < out_len)
    throw std::runtime_error{"EVP_PKEY_encrypt: unexpected length"};
  ciphertext.resize(out_len);

  return ciphertext;
}

template <>
std::vector<uint8_t> impl::rsa_decrypt(impl::rsa_key &key, uint8_t const *data,
                                       size_t len) {
  // RSA-OAEP decryption of data[0..len) with `key`, using SHA-256 as both
  // the OAEP digest and the MGF1 digest. This mirrors rsa_encrypt. Throws
  // openssl_error on any OpenSSL failure, including a padding check failure.
  EVP_PKEY_CTX_ptr ctx{EVP_PKEY_CTX_new(&*key, nullptr)};
  if (!ctx)
    throw openssl_error("EVP_PKEY_CTX_new");

  auto *const raw_ctx = &*ctx;
  if (EVP_PKEY_decrypt_init(raw_ctx) != 1)
    throw openssl_error("EVP_PKEY_decrypt_init");
  if (EVP_PKEY_CTX_set_rsa_padding(raw_ctx, RSA_PKCS1_OAEP_PADDING) != 1)
    throw openssl_error("EVP_PKEY_CTX_set_rsa_padding");

  auto const digest = EVP_sha256();
  if (EVP_PKEY_CTX_set_rsa_mgf1_md(raw_ctx, digest) != 1)
    throw openssl_error("EVP_PKEY_CTX_set_rsa_mgf1_md");
  if (EVP_PKEY_CTX_set_rsa_oaep_md(raw_ctx, digest) != 1)
    throw openssl_error("EVP_PKEY_CTX_set_rsa_oaep_md");

  // With a null output buffer, OpenSSL only reports an upper bound on the
  // plaintext size.
  size_t out_len = 0;
  if (EVP_PKEY_decrypt(raw_ctx, nullptr, &out_len, data, len) != 1)
    throw openssl_error("EVP_PKEY_decrypt");

  std::vector<uint8_t> plaintext(out_len);
  if (EVP_PKEY_decrypt(raw_ctx, plaintext.data(), &out_len, data, len) != 1)
    throw openssl_error("EVP_PKEY_decrypt");
  if (plaintext.size() < out_len)
    throw std::runtime_error{"EVP_PKEY_decrypt: unexpected length"};
  plaintext.resize(out_len);

  return plaintext;
}

template <>
std::vector<uint8_t> impl::rsa_sign(impl::rsa_key &key, uint8_t const *data,
                                    size_t len) {
  // RSASSA-PSS signature over data[0..len) with SHA-256, MGF1-SHA-256 and
  // the maximum salt length. The signature must be exactly kSignatureSize
  // bytes; any other length throws std::runtime_error.
  EVP_MD_CTX_ptr digest_ctx{EVP_MD_CTX_new()};
  if (!digest_ctx)
    throw openssl_error("EVP_MD_CTX_new");

  auto const sha256 = EVP_sha256();

  // Filled in by EVP_DigestSignInit; owned by digest_ctx, so it is not freed
  // here.
  EVP_PKEY_CTX *sign_ctx = nullptr;
  if (EVP_DigestSignInit(&*digest_ctx, &sign_ctx, sha256, nullptr, &*key) != 1)
    throw openssl_error("EVP_DigestSignInit");

  if (EVP_PKEY_CTX_set_rsa_padding(sign_ctx, RSA_PKCS1_PSS_PADDING) <= 0)
    throw openssl_error("EVP_PKEY_CTX_set_rsa_padding");
  if (EVP_PKEY_CTX_set_rsa_mgf1_md(sign_ctx, sha256) <= 0)
    throw openssl_error("EVP_PKEY_CTX_set_rsa_mgf1_md");
  if (EVP_PKEY_CTX_set_rsa_pss_saltlen(sign_ctx, RSA_PSS_SALTLEN_MAX) <= 0)
    throw openssl_error("EVP_PKEY_CTX_set_rsa_pss_saltlen");

  if (EVP_DigestSignUpdate(&*digest_ctx, data, len) != 1)
    throw openssl_error("EVP_DigestSignUpdate");

  // In: buffer capacity. Out: bytes actually written.
  auto produced = kSignatureSize;
  std::vector<uint8_t> signature(produced);
  if (EVP_DigestSignFinal(&*digest_ctx, signature.data(), &produced) != 1)
    throw openssl_error("EVP_DigestSignFinal");
  if (produced != kSignatureSize)
    throw std::runtime_error{"EVP_DigestSignFinal: unexpected length"};

  return signature;
}

template <>
impl::sym_key impl::generate_sym_key() {
  // Returns kAesKeySize random bytes drawn from OpenSSL's private RNG
  // (RAND_priv_bytes), for use as a symmetric key.
  std::vector<uint8_t> fresh_key(kAesKeySize);
  auto const status = RAND_priv_bytes(fresh_key.data(), kAesKeySize);
  if (status != 1)
    throw openssl_error("RAND_priv_bytes");
  return fresh_key;
}

template <>
std::tuple<impl::encrypt_ctx, std::vector<uint8_t>>
impl::encrypt_init(impl::sym_key &secret) {
  // Starts an AES-256-GCM encryption under `secret` with a freshly generated
  // random IV.
  //
  // Returns the cipher context, to be passed to encrypt_update and
  // encrypt_finish, and the IV, which decryption needs (decrypt_init takes it
  // as `iv`).
  // Throws std::runtime_error if `secret` is not kAesKeySize bytes, and
  // openssl_error on allocation, RNG, or initialization failure.
  auto cipher_mode = EVP_aes_256_gcm();

  // The project's size constants must agree with OpenSSL's view of the
  // cipher. GCM is a stream mode, hence a block size of 1.
  assert(kAesKeySize == EVP_CIPHER_get_key_length(cipher_mode));
  assert(kAesIvSize == EVP_CIPHER_get_iv_length(cipher_mode));
  assert(1 == EVP_CIPHER_get_block_size(cipher_mode));

  if (secret.size() != kAesKeySize)
    throw std::runtime_error{"encrypt_init: wrong length"};

  EVP_CIPHER_CTX_ptr ctx{EVP_CIPHER_CTX_new()};
  // Dereferencing a null context below would be undefined behavior.
  if (!ctx)
    throw openssl_error("EVP_CIPHER_CTX_new");

  std::vector<uint8_t> iv(kAesIvSize);
  if (RAND_bytes(iv.data(), kAesIvSize) != 1)
    throw openssl_error("RAND_bytes");

  if (EVP_EncryptInit_ex2(&*ctx, cipher_mode, secret.data(), iv.data(),
                          nullptr) != 1)
    throw openssl_error("EVP_EncryptInit");

  return {std::move(ctx), std::move(iv)};
}

template <>
std::vector<uint8_t> impl::encrypt_update(impl::encrypt_ctx &ctx,
                                          uint8_t const *plaintext,
                                          size_t len) {
  // Encrypts plaintext[0..len) with the context from encrypt_init and
  // returns the ciphertext.
  //
  // GCM is a stream mode, so the ciphertext has exactly `len` bytes. Throws
  // openssl_error if OpenSSL fails, and std::runtime_error if OpenSSL
  // produces an unexpected amount of output.

  // EVP_EncryptUpdate takes an `int` length. Casting a large `len` directly
  // would silently truncate it and encrypt only part of the input, so feed
  // OpenSSL chunks that always fit in an int.
  constexpr size_t kMaxUpdateChunk = size_t{1} << 30;

  std::vector<uint8_t> ciphertext(len);

  size_t offset = 0;
  while (offset < len) {
    size_t const remaining = len - offset;
    size_t const chunk =
        remaining < kMaxUpdateChunk ? remaining : kMaxUpdateChunk;
    int const chunk_len = static_cast<int>(chunk);

    int written = 0;
    if (EVP_EncryptUpdate(&*ctx, ciphertext.data() + offset, &written,
                          plaintext + offset, chunk_len) != 1)
      throw openssl_error("EVP_EncryptUpdate");
    if (written != chunk_len)
      throw std::runtime_error{"EVP_EncryptUpdate: unexpected length"};

    offset += chunk;
  }

  return ciphertext;
}

template <>
std::tuple<std::vector<uint8_t>, std::vector<uint8_t>>
impl::encrypt_finish(impl::encrypt_ctx &ctx) {
  // Finalizes the GCM encryption and returns {trailing ciphertext, tag}.
  // GCM is a stream mode, so there is never trailing ciphertext. The final
  // call is still required before the authentication tag can be read.
  int32_t trailing_len = 0;
  auto const finalized = EVP_EncryptFinal_ex(&*ctx, nullptr, &trailing_len);
  if (finalized != 1)
    throw openssl_error("EVP_EncryptFinal");
  if (trailing_len != 0)
    throw std::runtime_error{"EVP_EncryptFinal: unexpected length"};

  assert(kAesTagSize == EVP_CIPHER_CTX_get_tag_length(&*ctx));

  std::vector<uint8_t> tag(kAesTagSize);
  auto const got_tag = EVP_CIPHER_CTX_ctrl(&*ctx, EVP_CTRL_AEAD_GET_TAG,
                                           kAesTagSize, tag.data());
  if (got_tag != 1)
    throw openssl_error("EVP_CIPHER_CTX_ctrl(GET_TAG)");

  std::vector<uint8_t> trailing;
  return {std::move(trailing), std::move(tag)};
}

template <>
impl::decrypt_ctx impl::decrypt_init(impl::sym_key &secret, uint8_t const *iv,
                                     uint8_t const *tag) {
  // Starts an AES-256-GCM decryption under `secret`.
  //
  // `iv` is read at the cipher's default IV length (encrypt_init asserts
  // this equals kAesIvSize). `tag` must point to kAesTagSize bytes. The tag
  // is installed now and verified by EVP_DecryptFinal_ex in decrypt_finish.
  // Throws std::runtime_error if `secret` is not kAesKeySize bytes, and
  // openssl_error on allocation or initialization failure.

  // EVP_DecryptInit_ex2 reads a full AES-256 key from secret.data(). A
  // shorter key would cause an out-of-bounds read. This matches the check
  // in encrypt_init.
  if (secret.size() != kAesKeySize)
    throw std::runtime_error{"decrypt_init: wrong length"};

  auto cipher_mode = EVP_aes_256_gcm();

  EVP_CIPHER_CTX_ptr cipher_ctx{EVP_CIPHER_CTX_new()};
  // Dereferencing a null context below would be undefined behavior.
  if (!cipher_ctx)
    throw openssl_error("EVP_CIPHER_CTX_new");

  if (EVP_DecryptInit_ex2(&*cipher_ctx, cipher_mode, secret.data(), iv,
                          nullptr) != 1)
    throw openssl_error("EVP_DecryptInit");

  // The ctrl API takes a non-const void*; SET_TAG only reads from it.
  if (EVP_CIPHER_CTX_ctrl(&*cipher_ctx, EVP_CTRL_AEAD_SET_TAG, kAesTagSize,
                          const_cast<uint8_t *>(tag)) != 1)
    throw openssl_error("EVP_CIPHER_CTX_ctrl(SET_TAG)");

  return cipher_ctx;
}

template <>
std::vector<uint8_t> impl::decrypt_update(impl::decrypt_ctx &ctx,
                                          uint8_t const *ciphertext,
                                          size_t len) {
  // Decrypts ciphertext[0..len) with the context from decrypt_init and
  // returns the plaintext (exactly `len` bytes, since GCM is a stream mode).
  //
  // The returned plaintext is NOT yet authenticated. The tag is only
  // checked in decrypt_finish. Throws openssl_error if OpenSSL fails, and
  // std::runtime_error on an unexpected output length.

  // EVP_DecryptUpdate takes an `int` length. Casting a large `len` directly
  // would silently truncate it and decrypt only part of the input, so feed
  // OpenSSL chunks that always fit in an int.
  constexpr size_t kMaxUpdateChunk = size_t{1} << 30;

  std::vector<uint8_t> plaintext(len);

  size_t offset = 0;
  while (offset < len) {
    size_t const remaining = len - offset;
    size_t const chunk =
        remaining < kMaxUpdateChunk ? remaining : kMaxUpdateChunk;
    int const chunk_len = static_cast<int>(chunk);

    int written = 0;
    if (EVP_DecryptUpdate(&*ctx, plaintext.data() + offset, &written,
                          ciphertext + offset, chunk_len) != 1)
      throw openssl_error("EVP_DecryptUpdate");
    if (written != chunk_len)
      throw std::runtime_error{"EVP_DecryptUpdate: unexpected length"};

    offset += chunk;
  }

  return plaintext;
}

template <>
std::vector<uint8_t> impl::decrypt_finish(impl::decrypt_ctx &ctx) {
  // Completes the GCM decryption. This is where OpenSSL compares the
  // computed tag against the one installed in decrypt_init. A mismatch, or
  // any other failure, throws openssl_error. Plaintext returned earlier by
  // decrypt_update should not be trusted until this returns.
  int32_t trailing_len = 0;
  auto const verified = EVP_DecryptFinal_ex(&*ctx, nullptr, &trailing_len);
  if (verified != 1)
    throw openssl_error("EVP_DecryptFinal");

  // GCM is a stream mode, so finalization never yields extra plaintext.
  if (trailing_len != 0)
    throw std::runtime_error{"EVP_DecryptFinal: unexpected length"};

  return std::vector<uint8_t>{};
}

} // namespace detail
} // namespace jsc
