#include <jsc/detail/encoding.hpp>

#include <cstddef> // For size_t
#include <cstdint>
#include <string>
#include <tuple>
#include <vector>

namespace jsc {
namespace detail {

// Encoder alphabet; the overall approach follows
// https://stackoverflow.com/a/41094722
static constexpr char kBase64Chars[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/**
 * Encodes `len` bytes starting at `in` as base64 using the standard alphabet
 * ('+' and '/') with '=' padding.
 *
 * @param in  the input bytes; never read when `len` is 0 (may then be null).
 * @param len number of bytes to encode.
 * @returns a string of exactly 4 * ceil(len / 3) characters.
 */
std::string base64_encode(uint8_t const *in, size_t len) {
  std::string encoded((len + 2) / 3 * 4, '\0');
  size_t pos = 0;
  size_t i = 0;

  // Every full group of three bytes (24 bits) becomes four 6-bit symbols.
  for (; len - i > 2; i += 3) {
    uint8_t const b0 = in[i];
    uint8_t const b1 = in[i + 1];
    uint8_t const b2 = in[i + 2];
    encoded[pos++] = kBase64Chars[b0 >> 2];
    encoded[pos++] = kBase64Chars[((b0 & 0x03) << 4) | (b1 >> 4)];
    encoded[pos++] = kBase64Chars[((b1 & 0x0f) << 2) | (b2 >> 6)];
    encoded[pos++] = kBase64Chars[b2 & 0x3f];
  }

  // A 1- or 2-byte remainder is zero-extended to whole symbols and the
  // group is completed with '=' padding.
  size_t const rest = len - i;
  if (rest == 0) {
    return encoded;
  }
  uint8_t const b0 = in[i];
  encoded[pos] = kBase64Chars[b0 >> 2];
  if (rest == 2) {
    uint8_t const b1 = in[i + 1];
    encoded[pos + 1] = kBase64Chars[((b0 & 0x03) << 4) | (b1 >> 4)];
    encoded[pos + 2] = kBase64Chars[(b1 & 0x0f) << 2];
  } else {
    encoded[pos + 1] = kBase64Chars[(b0 & 0x03) << 4];
    encoded[pos + 2] = '=';
  }
  encoded[pos + 3] = '=';
  return encoded;
}

// Reverse lookup used by the decoder: maps a character's byte value to its
// 6-bit base64 value. Both the standard ('+' -> 62, '/' -> 63) and URL-safe
// ('-' -> 62, '_' -> 63) alphabets are accepted. No validation is done:
// bytes outside the alphabet (including '=') map to 0, so they silently
// decode as if they were 'A'.
// The initializer list stops at index 127; entries 128-255 are zero.
// NOTE(review): entry 127 (DEL) is 63 rather than 0, unlike every other
// non-alphabet byte; this looks unintentional — confirm before relying on it.
static constexpr uint8_t kBase64Reverse[256]{
    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
    0,  0,  0,  0,  0,  62, 0,  62, 0,  63, 52, 53, 54, 55, 56, 57, 58, 59, 60,
    61, 0,  0,  0,  0,  0,  0,  0,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10,
    11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0,  0,  0,  0,
    63, 0,  26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
    43, 44, 45, 46, 47, 48, 49, 50, 51, 0,  0,  0,  0,  63};

/**
 * Strips trailing '=' padding from `input` and computes the decoded size.
 *
 * At most two trailing '=' characters are removed. The remaining characters
 * are not validated here.
 *
 * @param input base64 text.
 * @returns {in, end, output size}, where:
 *          - `[in, end)` is the character range without padding;
 *          - `output size` is the number of decoded bytes: 3 per full
 *            4-character group, plus 1 or 2 for a trailing group of 2 or 3
 *            characters. A lone trailing character yields nothing.
 */
static std::tuple<uint8_t const *, uint8_t const *, size_t>
base64_decode_prepare(std::string_view input) {
  auto const *const in = reinterpret_cast<uint8_t const *>(input.data());
  auto const *end = in + input.size();

  // Valid padding is one or two '='; anything beyond that is left in range.
  for (int pad = 0; pad < 2 && end != in && end[-1] == '='; ++pad) {
    --end;
  }

  // Computed per 4-character group in size_t. The naive
  // `(end - in) * 3 / 4` is evaluated in signed ptrdiff_t, so it overflows
  // (undefined behavior) for inputs longer than PTRDIFF_MAX / 3 characters,
  // which is reachable on 32-bit targets. Otherwise the results are equal.
  auto const chars = static_cast<size_t>(end - in);
  size_t const output_size = chars / 4 * 3 + chars % 4 * 3 / 4;
  return {in, end, output_size};
}

/**
 * Decodes the base64 characters in `[in, end)` into `out`.
 *
 * Each full 4-character group produces 3 bytes. A trailing group of 2 or 3
 * characters produces 1 or 2 bytes, and a lone trailing character produces
 * nothing. Characters are looked up in kBase64Reverse without validation.
 *
 * @param in  first character; the range is expected to exclude '=' padding
 *            (as returned by base64_decode_prepare).
 * @param end one past the last character.
 * @param out destination; must have room for the byte count described above
 *            (the size computed by base64_decode_prepare).
 */
static void base64_decode_into(uint8_t const *in, uint8_t const *end,
                               uint8_t *out) {
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wimplicit-int-conversion"
#pragma clang diagnostic ignored "-Wsign-conversion"
#endif

  // Four 6-bit symbols are packed into a 24-bit group, then split into bytes.
  while (end - in >= 4) {
    uint32_t const group =
        static_cast<uint32_t>(kBase64Reverse[in[0]]) << 18 |
        static_cast<uint32_t>(kBase64Reverse[in[1]]) << 12 |
        static_cast<uint32_t>(kBase64Reverse[in[2]]) << 6 |
        static_cast<uint32_t>(kBase64Reverse[in[3]]);
    *out++ = static_cast<uint8_t>(group >> 16);
    *out++ = static_cast<uint8_t>((group >> 8) & 0xff);
    *out++ = static_cast<uint8_t>(group & 0xff);
    in += 4;
  }

  // Partial final group: 2 characters -> 1 byte, 3 characters -> 2 bytes.
  auto const rest = end - in;
  if (rest >= 2) {
    uint32_t group = static_cast<uint32_t>(kBase64Reverse[in[0]]) << 18 |
                     static_cast<uint32_t>(kBase64Reverse[in[1]]) << 12;
    if (rest == 3) {
      group |= static_cast<uint32_t>(kBase64Reverse[in[2]]) << 6;
    }
    out[0] = static_cast<uint8_t>(group >> 16);
    if (rest == 3) {
      out[1] = static_cast<uint8_t>((group >> 8) & 0xff);
    }
  }
#ifdef __clang__
#pragma clang diagnostic pop
#endif
}

/**
 * Decodes base64 text into raw bytes.
 *
 * Up to two trailing '=' are treated as padding. Characters are not validated,
 * so malformed input yields unspecified bytes rather than an error.
 */
std::vector<uint8_t> base64_decode(std::string_view input) {
  auto const prepared = base64_decode_prepare(input);
  uint8_t const *const first = std::get<0>(prepared);
  uint8_t const *const last = std::get<1>(prepared);
  std::vector<uint8_t> decoded(std::get<2>(prepared));

  base64_decode_into(first, last, decoded.data());
  return decoded;
}

/**
 * Decodes base64 text like base64_decode, but returns the bytes in a
 * std::string (useful when the payload is known to be text).
 */
std::string base64_decode_string(std::string_view input) {
  auto const prepared = base64_decode_prepare(input);
  std::string decoded(std::get<2>(prepared), '\0');

  auto *const dest = reinterpret_cast<uint8_t *>(decoded.data());
  base64_decode_into(std::get<0>(prepared), std::get<1>(prepared), dest);
  return decoded;
}

// Lowercase hexadecimal digits, indexed by nibble value.
static constexpr char kHexDigits[] = "0123456789abcdef";

/**
 * Encodes `len` bytes starting at `in` as lowercase hexadecimal.
 *
 * Each byte becomes two characters, high nibble first. `in` is never read
 * when `len` is 0.
 *
 * @returns a string of exactly `2 * len` characters.
 */
std::string hex_encode(uint8_t const *in, size_t len) {
  std::string hex;
  hex.reserve(len * 2);
  for (size_t i = 0; i < len; ++i) {
    uint8_t const byte = in[i];
    hex.push_back(kHexDigits[byte >> 4]);
    hex.push_back(kHexDigits[byte & 0xf]);
  }
  return hex;
}

} // namespace detail
} // namespace jsc
