#include <jsc/detail/openssl_utils.hpp>

#include <cstdint>
#include <cstring>   // For memcpy
#include <limits>    // For numeric_limits
#include <stdexcept> // For runtime_error
#include <string>
#include <string_view>
#include <vector>

#include <openssl/bio.h>
#include <openssl/bn.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/pem.h>

namespace jsc {
namespace detail {

std::runtime_error openssl_error(char const *origin) {
  return std::runtime_error{std::string{origin} + ": " +
                            ERR_error_string(ERR_get_error(), nullptr)};
}

/// Parses a PEM-encoded public key (via PEM_read_bio_PUBKEY) and requires it
/// to be an RSA key.
///
/// @param pem PEM text; need not be NUL-terminated.
/// @return Owning pointer to the parsed key.
/// @throws std::runtime_error if the input is too large for OpenSSL's int
///         length, if the BIO cannot be created, if parsing fails, or if the
///         key is not RSA.
EVP_PKEY_ptr parse_rsa_public_key(std::string_view pem) {
  // BIO_new_mem_buf takes an int length. A size above INT_MAX would be
  // truncated by the cast and may become negative. A negative length (as
  // documented for -1) makes OpenSSL call strlen() on a buffer that is not
  // necessarily NUL-terminated. Reject such inputs up front.
  if (pem.size() > static_cast<std::size_t>(std::numeric_limits<int>::max()))
    throw std::runtime_error{"PEM input too large"};

  BIO_ptr source{BIO_new_mem_buf(pem.data(), static_cast<int>(pem.size()))};
  if (!source)
    throw openssl_error("BIO_new_mem_buf");

  EVP_PKEY_ptr pkey{PEM_read_bio_PUBKEY(&*source, nullptr, nullptr, nullptr)};
  if (!pkey)
    throw openssl_error("PEM_read_bio_PUBKEY");

  if (!EVP_PKEY_is_a(&*pkey, "RSA"))
    throw std::runtime_error{"Key is not RSA"};

  return pkey;
}

/// Serializes an RSA public key into an "RSA1"-tagged binary blob.
///
/// Layout: the 4-byte magic "RSA1", then five 32-bit little-endian unsigned
/// integers (bit_length, e_len, n_len, p_len = 0, q_len = 0). These are
/// followed by e (e_len bytes) and n (n_len bytes) as produced by BN_bn2bin
/// (big-endian).
/// NOTE(review): this matches the BCRYPT_RSAKEY_BLOB public-key layout used
/// by Windows CNG, presumably the intended consumer; confirm.
///
/// @param pkey Key whose "bits", "e" and "n" parameters are serialized.
/// @return The serialized blob.
/// @throws std::runtime_error if a key parameter cannot be read or a big
///         number does not serialize to its expected length.
std::vector<uint8_t> rsa_public_key_to_blob(EVP_PKEY const &pkey) {
  size_t bits{};
  if (EVP_PKEY_get_size_t_param(&pkey, "bits", &bits) != 1)
    throw openssl_error("EVP_PKEY_get_size_t_param(bits)");

  BIGNUM *e_raw{};
  if (EVP_PKEY_get_bn_param(&pkey, "e", &e_raw) != 1 || !e_raw)
    throw openssl_error("EVP_PKEY_get_bn_param(e)");
  BN_ptr e{e_raw};

  BIGNUM *n_raw{};
  if (EVP_PKEY_get_bn_param(&pkey, "n", &n_raw) != 1 || !n_raw)
    throw openssl_error("EVP_PKEY_get_bn_param(n)");
  BN_ptr n{n_raw};

  constexpr char kMagic[4] = {'R', 'S', 'A', '1'};
  // Magic plus five 32-bit fields; computed explicitly rather than relying
  // on the padding-free layout of a struct.
  constexpr size_t kHeaderSize = sizeof(kMagic) + 5 * sizeof(uint32_t);

  auto const e_len = static_cast<uint32_t>(BN_num_bytes(&*e));
  auto const n_len = static_cast<uint32_t>(BN_num_bytes(&*n));

  std::vector<uint8_t> buffer(kHeaderSize + e_len + n_len);
  uint8_t *out = buffer.data();

  // Writes `value` as 4 little-endian bytes and advances `out`. Writing byte
  // by byte yields the same blob on every host and avoids type-punning a
  // struct onto the vector's storage.
  auto const put_u32_le = [&out](uint32_t value) {
    for (int i = 0; i < 4; ++i)
      *out++ = static_cast<uint8_t>(value >> (8 * i));
  };

  std::memcpy(out, kMagic, sizeof(kMagic));
  out += sizeof(kMagic);
  put_u32_le(static_cast<uint32_t>(bits));
  put_u32_le(e_len);
  put_u32_le(n_len);
  put_u32_le(0); // p_len: no private prime is serialized
  put_u32_le(0); // q_len: no private prime is serialized

  // Pointer arithmetic on data() stays valid even when the offset equals
  // buffer.size(). buffer[offset] would be UB in that case.
  if (static_cast<uint32_t>(BN_bn2bin(&*e, out)) != e_len)
    throw openssl_error("BN_bn2bin");
  out += e_len;
  if (static_cast<uint32_t>(BN_bn2bin(&*n, out)) != n_len)
    throw openssl_error("BN_bn2bin");

  return buffer;
}

} // namespace detail
} // namespace jsc
