#include <jsc/client_key.hpp>

#include <cstddef> // For size_t
#include <istream>
#include <ostream>
#include <stdexcept>
#include <string_view>
#include <vector>

#include <nlohmann/json.hpp>

#include <jsc/detail/encoding.hpp>

namespace jsc {

using namespace detail;

/// Appends a copy of every element of @p from to the end of @p to,
/// preserving their order. @p from is left unchanged.
template <typename T>
static void append_vector(std::vector<T> &to, std::vector<T> const &from) {
  auto const insert_at = std::end(to);
  to.insert(insert_at, std::begin(from), std::end(from));
}

/// Decrypts a self-contained JSON response.
///
/// @param response JSON object carrying base64-encoded "nonce", "mac" and
///                 "ciphertext" string fields.
/// @return The bytes produced by the decryptor for the ciphertext, followed by
///         whatever decryptor.finish() returns.
/// @throws ParamError if @p response is not valid JSON or one of the fields is
///         missing or not a string (any nlohmann::json::exception is mapped).
///         Errors raised by base64_decode or the decryptor propagate unchanged.
///
/// NOTE(review): the ParamError names "metadata" although this parameter is
/// called `response`; kept as-is in case callers match on it.
std::vector<uint8_t> ResponseKey::decrypt(std::string_view response) {
  std::vector<uint8_t> nonce;
  std::vector<uint8_t> mac;
  std::vector<uint8_t> body;
  try {
    auto const parsed = nlohmann::json::parse(response);
    nonce = base64_decode(parsed.at("nonce").get<std::string_view>());
    mac = base64_decode(parsed.at("mac").get<std::string_view>());
    body = base64_decode(parsed.at("ciphertext").get<std::string_view>());
  } catch (nlohmann::json::exception const &err) {
    throw ParamError{"metadata", err.what()};
  }

  auto decryptor = key.decrypt(nonce, mac);
  auto result = decryptor.decrypt(body.data(), body.size());
  auto const tail = decryptor.finish();
  append_vector(result, tail);
  return result;
}

// Number of bytes read from the input stream per iteration by the streaming
// encrypt/decrypt routines below (4 KiB).
static constexpr size_t kBlockSize = 4 * 1024;

/// Decrypts @p source in kBlockSize chunks and writes the result to @p dest.
/// The nonce and tag travel separately in @p metadata.
///
/// @param metadata JSON object with base64-encoded "nonce" and "mac" string
///                 fields.
/// @param source   Ciphertext input; read until end of stream.
/// @param dest     Receives the decrypted bytes.
/// @throws ParamError         if @p metadata is not valid JSON or a field is
///                            missing or not a string.
/// @throws std::runtime_error if @p source is already in a failed state on
///                            entry, if it reports a read error (badbit), or
///                            if writing to @p dest fails.
///
/// NOTE(review): decrypted chunks are written to @p dest before
/// decryptor.finish() runs. If tag verification happens in finish() (likely,
/// but not visible here), @p dest may already hold unauthenticated data when
/// verification fails. Callers should discard the output on any exception.
void ResponseKey::decrypt(std::string_view metadata, std::istream &source,
                          std::ostream &dest) {
  std::vector<uint8_t> iv, tag;
  try {
    auto json = nlohmann::json::parse(metadata);
    iv = base64_decode(json.at("nonce").get<std::string_view>());
    tag = base64_decode(json.at("mac").get<std::string_view>());
  } catch (nlohmann::json::exception const &e) {
    throw ParamError{"metadata", e.what()};
  }

  // A stream that is already failed (e.g. an ifstream whose open() failed)
  // would otherwise be silently treated as empty input.
  if (!source)
    throw std::runtime_error("Failed to read from input stream");

  auto decryptor = key.decrypt(iv, tag);
  std::vector<uint8_t> input(kBlockSize);
  std::vector<uint8_t> output;

  while (source) {
    source.read(reinterpret_cast<char *>(input.data()),
                static_cast<std::streamsize>(input.size()));
    std::streamsize bytes_read = source.gcount();

    // A short read at end of input sets eofbit|failbit but still reports the
    // bytes it got, so the final partial chunk is processed before the loop
    // condition ends iteration.
    if (!bytes_read)
      break;

    output = decryptor.decrypt(input.data(), static_cast<size_t>(bytes_read));

    dest.write(reinterpret_cast<char const *>(output.data()),
               static_cast<std::streamsize>(output.size()));
    if (!dest)
      throw std::runtime_error("Failed to write to output stream");
  }

  // Normal end of input leaves only eofbit/failbit set. badbit means the
  // stream itself reported an error, so the ciphertext was not fully consumed.
  if (source.bad())
    throw std::runtime_error("Failed to read from input stream");

  // Emit whatever final output finish() produces.
  output = decryptor.finish();
  dest.write(reinterpret_cast<char const *>(output.data()),
             static_cast<std::streamsize>(output.size()));
  if (!dest)
    throw std::runtime_error("Failed to write to output stream");
}

namespace detail {

/// Encrypts an in-memory buffer and packs everything needed to decrypt it
/// into a single JSON document.
///
/// @param data Bytes to encrypt. Whether a null pointer is acceptable when
///             @p len is 0 depends on encryptor.encrypt (TODO confirm).
/// @param len  Number of bytes at @p data.
/// @return EncryptResult whose first member is the serialized JSON with
///         base64-encoded "key" (the encrypted key), "nonce", "ciphertext" and
///         "mac" fields. Its second member is initialised from the decrypt_key
///         returned by key.encrypt().
EncryptResult SessionKey::encrypt(uint8_t const *data, size_t len) {
  auto [encrypted_key, iv, encryptor, decrypt_key] = key.encrypt();

  auto ciphertext = encryptor.encrypt(data, len);

  // finish() yields any trailing ciphertext bytes plus the tag, which is
  // serialized under "mac".
  auto [remaining, tag] = encryptor.finish();
  append_vector(ciphertext, remaining);

  nlohmann::json metadata{
      {"key", base64_encode(encrypted_key)},
      {"nonce", base64_encode(iv)},
      {"ciphertext", base64_encode(ciphertext)},
      {"mac", base64_encode(tag)},
  };
  return {metadata.dump(), {std::move(decrypt_key)}};
}

/// Convenience overload that encrypts the characters of @p data.
///
/// The characters are reinterpreted as raw bytes and forwarded unchanged to
/// encrypt(uint8_t const *, size_t), whose result is returned as-is.
EncryptResult SessionKey::encrypt(std::string_view data) {
  auto const *const bytes = reinterpret_cast<uint8_t const *>(data.data());
  auto const count = data.size();
  return encrypt(bytes, count);
}

/// Encrypts @p source in kBlockSize chunks, writing ciphertext to @p dest.
///
/// @param source Plaintext input; read until end of stream.
/// @param dest   Receives the ciphertext, including any trailing bytes
///               returned by encryptor.finish().
/// @return EncryptResult whose first member is the serialized JSON with
///         base64-encoded "key" (the encrypted key), "nonce" and "mac"
///         fields. Unlike the buffer overload there is no "ciphertext" field,
///         because the ciphertext was written to @p dest. The second member is
///         initialised from the decrypt_key returned by key.encrypt().
/// @throws std::runtime_error if @p source is already in a failed state on
///         entry, if it reports a read error (badbit), or if writing to
///         @p dest fails.
EncryptResult SessionKey::encrypt(std::istream &source, std::ostream &dest) {
  // Without this check, a failed stream (e.g. an ifstream whose open()
  // failed) would silently produce a valid encryption of empty input.
  if (!source)
    throw std::runtime_error("Failed to read from input stream");

  auto [encrypted_key, iv, encryptor, decrypt_key] = key.encrypt();
  std::vector<uint8_t> input(kBlockSize);
  std::vector<uint8_t> output;

  while (source) {
    source.read(reinterpret_cast<char *>(input.data()),
                static_cast<std::streamsize>(input.size()));
    std::streamsize bytes_read = source.gcount();

    // A short read at end of input sets eofbit|failbit but still reports the
    // bytes it got, so the final partial chunk is processed before the loop
    // condition ends iteration.
    if (!bytes_read)
      break;

    output = encryptor.encrypt(input.data(), static_cast<size_t>(bytes_read));

    dest.write(reinterpret_cast<char const *>(output.data()),
               static_cast<std::streamsize>(output.size()));

    if (!dest)
      throw std::runtime_error("Failed to write to output stream");
  }

  // Normal end of input leaves only eofbit/failbit set. badbit means the
  // stream reported an error. Finishing anyway would yield a valid tag over
  // truncated plaintext, so fail instead.
  if (source.bad())
    throw std::runtime_error("Failed to read from input stream");

  // finish() yields any trailing ciphertext plus the tag, serialized as "mac".
  auto [remaining, tag] = encryptor.finish();
  dest.write(reinterpret_cast<char const *>(remaining.data()),
             static_cast<std::streamsize>(remaining.size()));
  if (!dest)
    throw std::runtime_error("Failed to write to output stream");

  nlohmann::json metadata{
      {"key", base64_encode(encrypted_key)},
      {"nonce", base64_encode(iv)},
      {"mac", base64_encode(tag)},
  };
  return {metadata.dump(), {std::move(decrypt_key)}};
}

} // namespace detail

} // namespace jsc
