#include <jsc/vbs/host_key.hpp>

#include <cstddef> // For size_t
#include <cstdint>
#include <string_view>
#include <vector>

#include <jsc/error.hpp>
#include <jsc/vbs/host.hpp>
#include <jsc/vbs/interface.hpp>

#ifndef JSC_CRYPTO_OPENSSL
#include <jsc/detail/openssl_utils.hpp>
#endif

namespace jsc {
namespace detail {

/// Releases the enclave-side object referenced by `id`, if there is one.
///
/// Nothing happens when no enclave is attached or `id` is unset (falsy).
/// Any exception thrown by the `vbs::free_object` call is swallowed: freeing
/// is best-effort, and a destructor must not propagate exceptions.
EnclaveObject::~EnclaveObject() noexcept {
  bool const owns_object = enclave && id;
  if (owns_object) {
    try {
      enclave->call<vbs::free_object>(id);
    } catch (...) {
      // Deliberately ignored: cleanup failure cannot be reported from here.
    }
  }
}

/// Feeds one chunk of ciphertext into the enclave decryption context.
///
/// @param ciphertext  Pointer to `len` bytes of ciphertext.
/// @param len         Number of ciphertext bytes; zero yields an empty result
///                    without calling into the enclave.
/// @return Plaintext bytes emitted by `vbs::decrypt_update` for this chunk.
/// @throws StateError  if this decryptor is not bound to an enclave object.
/// @throws CryptoError if the enclave reports more output than the buffer holds.
std::vector<uint8_t> Decryptor::decrypt(uint8_t const *ciphertext, size_t len) {
  if (!enclave || !id)
    throw StateError{"enclave key not set"};
  if (len == 0)
    return {};

  // Offer len + kAesBlockSize - 1 bytes of room; the enclave is expected to
  // overwrite `produced` with the number of bytes it actually wrote.
  auto produced = len + kAesBlockSize - 1;
  std::vector<uint8_t> out(produced);
  enclave->call<vbs::decrypt_update>(id, ciphertext, len, out.data(),
                                     &produced);
  if (out.size() < produced)
    throw CryptoError{"unexpected length"};
  out.resize(produced);

  return out;
}

/// Finalizes the enclave decryption context and returns any trailing plaintext.
///
/// @return Remaining plaintext bytes (at most kAesBlockSize - 1) produced by
///         `vbs::decrypt_finish`.
/// @throws StateError  if this decryptor is not bound to an enclave object.
/// @throws CryptoError if the enclave reports more output than the buffer holds.
std::vector<uint8_t> Decryptor::finish() {
  if (!enclave || !id)
    throw StateError{"enclave key not set"};

  // In: buffer capacity. Out: number of bytes the enclave wrote.
  auto remaining = kAesBlockSize - 1;
  std::vector<uint8_t> tail(remaining);
  enclave->call<vbs::decrypt_finish>(id, tail.data(), &remaining);
  if (tail.size() < remaining)
    throw CryptoError{"unexpected length"};
  tail.resize(remaining);

  return tail;
}

/// Opens a new decryption context in the enclave using this key.
///
/// @param iv   Initialization vector; must be exactly kAesIvSize bytes.
/// @param tag  Authentication tag; must be exactly kAesTagSize bytes.
/// @return A Decryptor bound to the context id returned by `vbs::decrypt_init`.
/// @throws StateError if this key is not bound to an enclave object.
/// @throws ParamError if `iv` or `tag` has the wrong length (checked in that
///                    order).
Decryptor DecryptingKey::decrypt(std::vector<uint8_t> const &iv,
                                 std::vector<uint8_t> const &tag) {
  if (!enclave || !id)
    throw StateError{"enclave key not set"};

  bool const iv_ok = iv.size() == kAesIvSize;
  if (!iv_ok)
    throw ParamError{"iv", "Wrong length"};

  bool const tag_ok = tag.size() == kAesTagSize;
  if (!tag_ok)
    throw ParamError{"tag", "Wrong length"};

  auto ctx_id = enclave->call<vbs::decrypt_init>(id, iv.data(), tag.data());
  return {enclave, ctx_id};
}

/// Encrypts one chunk of plaintext through the enclave encryption context.
///
/// @param plaintext  Pointer to `len` bytes to encrypt.
/// @param len        Number of bytes; zero yields an empty result without
///                   calling into the enclave.
/// @return Ciphertext bytes emitted by `vbs::encrypt_update` for this chunk.
/// @throws StateError  if this encryptor is not bound to an enclave object.
/// @throws CryptoError if the enclave reports more output than the buffer holds.
std::vector<uint8_t> Encryptor::encrypt(uint8_t const *plaintext, size_t len) {
  if (!enclave || !id)
    throw StateError{"enclave key not set"};
  if (len == 0)
    return {};

  // Offer len + kAesBlockSize - 1 bytes of room; the enclave is expected to
  // overwrite `produced` with the number of bytes it actually wrote.
  auto produced = len + kAesBlockSize - 1;
  std::vector<uint8_t> out(produced);
  enclave->call<vbs::encrypt_update>(id, plaintext, len, out.data(),
                                     &produced);
  if (out.size() < produced)
    throw CryptoError{"unexpected length"};
  out.resize(produced);

  return out;
}

/// Finalizes the enclave encryption context.
///
/// @return EncryptFinish holding any trailing ciphertext bytes (at most
///         kAesBlockSize - 1) and the kAesTagSize-byte tag written by
///         `vbs::encrypt_finish`.
/// @throws StateError  if this encryptor is not bound to an enclave object.
/// @throws CryptoError if the enclave reports more output than the buffer holds.
EncryptFinish Encryptor::finish() {
  if (!enclave || !id)
    throw StateError{"enclave key not set"};

  // In: buffer capacity. Out: number of ciphertext bytes the enclave wrote.
  auto tail_len = kAesBlockSize - 1;
  std::vector<uint8_t> tail(tail_len);
  std::vector<uint8_t> auth_tag(kAesTagSize);
  enclave->call<vbs::encrypt_finish>(id, tail.data(), &tail_len,
                                     auth_tag.data());
  if (tail.size() < tail_len)
    throw CryptoError{"unexpected length"};
  tail.resize(tail_len);

  return {std::move(tail), std::move(auth_tag)};
}

/// Begins an encryption session through `vbs::encrypt_init`.
///
/// The enclave fills in an encrypted key blob (up to kEncryptedKeySize bytes)
/// and a kAesIvSize-byte IV. It returns two enclave object ids:
///   - the encryptor context (the call's return value);
///   - `sym_key_id`, which is wrapped as a DecryptingKey.
/// Judging by the names, the latter is the symmetric session key; confirm this
/// against the enclave implementation.
///
/// @return EncryptInit with the encrypted key, IV, encryptor and decrypt key.
/// @throws StateError  if this key is not bound to an enclave object.
/// @throws CryptoError if the enclave reports an encrypted key longer than
///                     kEncryptedKeySize.
EncryptInit EncryptingKey::encrypt() {
  if (!enclave || !id)
    throw StateError{"enclave key not set"};

  // In: buffer capacity. Out: number of encrypted-key bytes written.
  auto enc_key_len = kEncryptedKeySize;
  std::vector<uint8_t> enc_key(enc_key_len);
  std::vector<uint8_t> iv(kAesIvSize);

  // Value-initialize the out-parameter. If the enclave call returns without
  // writing it, the id must not be indeterminate: reading it would be UB, and
  // an owning wrapper could try to free a garbage id. A zero/empty id is what
  // ~EnclaveObject treats as "nothing to free".
  vbs::object_id_t sym_key_id{};
  auto encryptor_id = enclave->call<vbs::encrypt_init>(
      id, enc_key.data(), &enc_key_len, iv.data(), &sym_key_id);

  // Wrap both enclave objects *before* validating the output, so a throw below
  // still releases them. This assumes Encryptor and DecryptingKey own their
  // ids the way EnclaveObject does; confirm in host_key.hpp.
  Encryptor encryptor{enclave, encryptor_id};
  DecryptingKey decrypt_key{enclave, sym_key_id};

  if (enc_key_len > kEncryptedKeySize)
    throw CryptoError{"unexpected length"};
  enc_key.resize(enc_key_len);

  return {.encrypted_key{std::move(enc_key)},
          .iv{std::move(iv)},
          .encryptor{std::move(encryptor)},
          .decrypt_key{std::move(decrypt_key)}};
}

std::vector<uint8_t> SigningKey::sign(uint8_t const *data, size_t len) {
  if (!enclave || !id)
    throw StateError{"enclave key not set"};

  auto signature_len = kSignatureSize;
  std::vector<uint8_t> signature(signature_len);

  enclave->call<vbs::rsa_sign>(id, data, len, signature.data(), &signature_len);
  if (signature_len != kSignatureSize)
    throw CryptoError{"unexpected length"};

  return signature;
}

std::vector<uint8_t> SigningKey::sign(std::string_view data) {
  return sign(reinterpret_cast<uint8_t const *>(data.data()), data.size());
}

/// Loads a PEM-encoded RSA public key into the enclave.
///
/// Without JSC_CRYPTO_OPENSSL, the PEM is parsed on the host
/// (parse_rsa_public_key) and converted to a key blob
/// (rsa_public_key_to_blob) before being passed to the enclave. With
/// JSC_CRYPTO_OPENSSL defined, the PEM text is passed to the enclave
/// unchanged.
///
/// @param pem  PEM-encoded RSA public key.
/// @return An EncryptingKey bound to this store's enclave and the loaded key id.
EncryptingKey EnclaveKeyStore::load_rsa_public_key(std::string_view pem) {
#ifndef JSC_CRYPTO_OPENSSL
  // Host-side parse, then hand the serialized blob to the enclave.
  auto blob = rsa_public_key_to_blob(*parse_rsa_public_key(pem));
  auto loaded_id = enclave.call<vbs::load_rsa_public_key>(
      reinterpret_cast<char const *>(blob.data()), blob.size());
#else
  // The PEM text is forwarded verbatim.
  auto loaded_id =
      enclave.call<vbs::load_rsa_public_key>(pem.data(), pem.size());
#endif
  return {&enclave, loaded_id};
}

/// Loads a serialized RSA private key into the enclave.
///
/// @param blob  Key blob, forwarded to `vbs::load_rsa_private_key` unchanged.
/// @return A SigningKey bound to this store's enclave and the loaded key id.
SigningKey EnclaveKeyStore::load_rsa_private_key(std::string_view blob) {
  auto loaded_id =
      enclave.call<vbs::load_rsa_private_key>(blob.data(), blob.size());
  return {&enclave, loaded_id};
}

} // namespace detail
} // namespace jsc
