import hashlib
import os
import tarfile
from dataclasses import dataclass
from pathlib import Path
from time import time
from multiprocessing import Pool
import concurrent.futures
from functools import reduce

import shutil
import requests
import tos
from tos.utils import SizeAdapter, MergeProcess
from tos import DataTransferType
from typing_extensions import TYPE_CHECKING, Dict, List, Optional
from bytedance.tks.client import TKSClient, TKSConfig, utc8_timestamp
from bytedance.tks.crypto.aes import gcm_decrypt, gcm_encrypt

if TYPE_CHECKING:
    from _typeshed import GenericPath
else:
    GenericPath = str


def upload_files(args):
    """Upload one local file to TOS using a multipart upload.

    Intended as a ``multiprocessing.Pool.map`` worker, so every input arrives in
    the single dict ``args`` with the keys ``bucket_name``, ``path``,
    ``volc_ak``, ``volc_sk``, ``region`` and ``endpoint``. The local ``path``
    string is used verbatim as the object key. The object is created with a
    public-read ACL and the standard storage class, uploaded in 5 MiB parts.

    If uploading a part or completing the upload fails, the multipart upload is
    aborted (so orphaned parts are not left in the bucket) and the original
    exception is re-raised.
    """
    bucket_name = args["bucket_name"]
    path = args["path"]
    volc_ak = args["volc_ak"]
    volc_sk = args["volc_sk"]
    region = args["region"]
    endpoint = args["endpoint"]

    tos_client = tos.TosClientV2(volc_ak, volc_sk, endpoint, region)

    total_size = os.path.getsize(path)
    part_size = 5 * 1024 * 1024

    def percentage(consumed_bytes: int, total_bytes: int, rw_once_bytes: int, type: DataTransferType):
        # Progress callback handed to the SDK; `type` keeps the original parameter name.
        if total_bytes:
            rate = int(100 * float(consumed_bytes) / float(total_bytes))
            # Throttle output: only log when consumed_bytes // rw_once_bytes is a multiple of 1000.
            if rw_once_bytes != 0 and ((consumed_bytes // rw_once_bytes) % 1000 == 0):
                print(
                    "        [{}] rate:{}, consumed_bytes:{},total_bytes:{}, rw_once_bytes:{}, type:{}".format(
                        path, rate, consumed_bytes, total_bytes, rw_once_bytes, type
                    )
                )

    # Unlike a plain upload, multipart progress must be merged across all parts
    # into a single file-level progress callback.
    part_count = (total_size + part_size - 1) // part_size
    data_transfer_listener = MergeProcess(percentage, total_size, part_count, 0)

    # Start the multipart upload task. The storage class and ACL of the final
    # object are fixed at this point.
    multi_result = tos_client.create_multipart_upload(
        bucket_name,
        path,
        acl=tos.ACLType.ACL_Public_Read,
        storage_class=tos.StorageClassType.Storage_Class_Standard,
    )
    upload_id = multi_result.upload_id

    try:
        parts = []
        # NOTE(review): a zero-byte file yields no parts, so completion is
        # attempted with an empty part list; verify TOS accepts that.
        with open(path, "rb") as f:
            part_number = 1
            offset = 0
            while offset < total_size:
                num_to_upload = min(part_size, total_size - offset)
                out = tos_client.upload_part(
                    bucket_name,
                    path,
                    upload_id,
                    part_number,
                    content=SizeAdapter(f, num_to_upload, init_offset=offset),
                    data_transfer_listener=data_transfer_listener,
                )
                parts.append(out)
                offset += num_to_upload
                part_number += 1

        # Stitch the uploaded parts into the final object.
        tos_client.complete_multipart_upload(bucket_name, path, upload_id, parts)
    except Exception:
        # Abort so already-uploaded parts are not left behind; a failure to
        # abort must not mask the original error.
        try:
            tos_client.abort_multipart_upload(bucket_name, path, upload_id)
        except Exception as abort_err:
            print(f"        [{path}] failed to abort multipart upload {upload_id}: {abort_err}")
        raise


def _compress_model_dir(model_dir: GenericPath) -> Path:
    compressed_file_name = os.path.basename(model_dir) + ".tar"
    compressed_file_path = Path(compressed_file_name)
    with tarfile.open(compressed_file_path, "w") as tar:
        tar.add(model_dir, arcname=os.path.basename(model_dir))
    return compressed_file_path


def _decompress_model_dir(compressed_file: GenericPath, output_path: GenericPath) -> None:
    with tarfile.open(compressed_file, "r") as tar:
        tar.extractall(path=output_path)


DEFAULT_PARTITION_SIZE = 2 * 1024 * 1024 * 1024


def _partition(model_file: GenericPath, partition_size=DEFAULT_PARTITION_SIZE) -> List[bytes]:
    with open(model_file, "rb") as fin:
        plaintext_bytes = fin.read()
    chunks: List[bytes] = []
    offset = 0
    while offset < len(plaintext_bytes):
        chunks.append(plaintext_bytes[offset : offset + partition_size])  # noqa: E203
        offset += partition_size
    return chunks


def _decrypt_chunk(params):
    """Decrypt a single encrypted part file and return its plaintext bytes.

    Shaped as a ``Pool.map`` worker: ``params`` is a dict carrying the file
    location under ``"enc_chunk_path"`` and the key under ``"key"``, both
    passed straight to ``open`` / ``gcm_decrypt``.
    """
    chunk_path, chunk_key = params["enc_chunk_path"], params["key"]
    started_at = time()
    with open(chunk_path, "rb") as handle:
        ciphertext = handle.read()
    plaintext = gcm_decrypt(chunk_key, ciphertext)
    print(f"Decrypt chunk: {chunk_path} done, costs {time() - started_at} seconds.")
    return plaintext


def md5_hash(data: bytes, salt: str = "") -> str:
    """Return the hex MD5 digest of ``data`` followed by the encoded ``salt``."""
    hasher = hashlib.md5()
    hasher.update(data)
    hasher.update(salt.encode())
    return hasher.hexdigest()


def copy_folder(source_folder, destination_folder):
    """Recursively copy ``source_folder`` to ``destination_folder`` (best effort).

    Nothing is raised and ``None`` is always returned; the outcome is only
    reported via ``print``:

    * missing source folder -> message, nothing copied;
    * destination already exists -> message, destination left untouched;
    * permission problems or any other error -> message.
    """
    try:
        if os.path.exists(source_folder):
            shutil.copytree(source_folder, destination_folder)
            print(f"        Folder copied from '{source_folder}' to '{destination_folder}'.")
        else:
            print(f"Source folder '{source_folder}' does not exist.")
    except FileExistsError:
        print(f"        Destination folder '{destination_folder}' already exists.")
    except PermissionError:
        print("        Permission denied: Unable to copy folder.")
    except Exception as err:
        print(f"        An error occurred: {err}")


@dataclass
class EncryptionConfig:
    """Credentials and key location consumed by ``JeddakModelEncrypter``.

    ``app_id`` and ``password`` are required; the key-ring and key identifiers
    are optional and default to ``None``.
    """

    # Application identifier (used for the TKS client and as the PCC login user name).
    app_id: str
    # Account password; only an MD5 of it salted with a timestamp is sent on login.
    password: str
    # Optional key-ring identifier, copied onto the encrypter instance.
    ring_id: Optional[str] = None
    # Optional key identifier, copied onto the encrypter instance.
    key_id: Optional[str] = None


class JeddakModelEncrypter:
    """Encrypt, decrypt, fingerprint and publish model artifacts.

    Encryption packs a model directory into a tar archive, splits it into
    partitions and encrypts each partition with ``gcm_encrypt`` under a freshly
    generated 32-byte data key. ``encrypt_model_and_upload`` additionally
    imports that key into TKS and uploads the encrypted parts to TOS.
    """

    # Read size used when streaming files through SHA-256, so hashing memory
    # stays flat regardless of file size.
    _HASH_BLOCK_SIZE = 1024 * 1024

    def __init__(self, encryption_config: Optional[EncryptionConfig] = None):
        """Create an encrypter.

        Args:
            encryption_config: When given, its fields are copied onto the
                instance and a TLS-enabled ``TKSClient`` is built. ``import_key``,
                ``get_key`` and the PCC login depend on it. Local
                encryption/decryption, hashing and ``encrypt_model_and_upload``
                (which builds its own TKS client) do not.
        """
        if encryption_config is not None:
            self.app_id = encryption_config.app_id
            self.password = encryption_config.password
            self.ring_id = encryption_config.ring_id
            self.key_id = encryption_config.key_id

            tks_config = TKSConfig(enable_tls=True)
            self.tks_client = TKSClient(self.app_id, tks_config)

        self.login_url = "https://jeddakchain.bytedance.com/api/user/login"

    def encrypt_model(self, model_path: GenericPath, output_path: GenericPath) -> bytes:
        """Encrypt ``model_path`` into ``enc-part-NNNNN`` files under ``output_path``.

        A directory is first packed into ``<name>.tar`` in the current working
        directory. That plaintext archive is deleted again once it has been read.
        A single file is encrypted as-is. The content is split with
        ``_partition`` and each partition is encrypted independently.

        Returns:
            The random 32-byte data key. The caller must keep it (e.g. import it
            into TKS), or the output cannot be decrypted.
        """
        os.makedirs(output_path, exist_ok=True)

        data_key = os.urandom(32)

        print("        Start compressing model files.")
        t0 = time()
        # Only directories need packing; remember whether we created the archive
        # so we never delete a file the caller passed in.
        created_archive = os.path.isdir(model_path)
        model_file = _compress_model_dir(model_path) if created_archive else model_path
        print(f"        Compress done, costs {time() - t0} seconds.")

        try:
            chunks = _partition(model_file)
        finally:
            if created_archive:
                # The archive is a plaintext copy of the whole model; don't leave it behind.
                os.remove(model_file)
        print(f"        Total {len(chunks)} chunks.")

        t0 = time()
        for i, chunk in enumerate(chunks):
            enc_chunk = gcm_encrypt(data_key, chunk)
            with open(os.path.join(output_path, f"enc-part-{i:05}"), "wb") as fout:
                fout.write(enc_chunk)
        print(f"        Encrypt done, costs {time() - t0} seconds.")

        return data_key

    def decrypt_model(
        self, enc_model_path: GenericPath, key: bytes, output_path: GenericPath, processes=8
    ) -> None:
        """Decrypt an encrypted model and extract it into ``output_path``.

        ``enc_model_path`` is either a directory of encrypted part files or a
        single encrypted file. Part files are decrypted in parallel by
        ``processes`` worker processes and concatenated in file-name order.
        The concatenated plaintext is treated as a tar archive. It is written
        to ``./decrypted_model.tar``, extracted, and then removed.

        All plaintext is held in memory before it is written out.
        """
        os.makedirs(output_path, exist_ok=True)
        t0 = time()
        plaintext_bytes = bytearray()
        if os.path.isdir(enc_model_path):
            print(f"Start decrypt enc_model_path: {enc_model_path}")
            # Sorted names restore partition order (enc-part-00000, 00001, ...).
            params_list = [
                {"enc_chunk_path": os.path.join(enc_model_path, filename), "key": key}
                for filename in sorted(os.listdir(enc_model_path))
            ]
            # Only spin up workers when there is parallel work, and let the
            # context manager tear them down even if a chunk fails to decrypt.
            with Pool(processes) as pool:
                chunks = pool.map(_decrypt_chunk, params_list)
            print(f"Pool map done, costs {time() - t0} seconds.")
            for chunk in chunks:
                plaintext_bytes += chunk
            del chunks
        else:
            with open(enc_model_path, "rb") as fin:
                enc_chunk = fin.read()
            plaintext_bytes += gcm_decrypt(key, enc_chunk)
        print(f"Decrypt done, costs {time() - t0} seconds.")

        plaintext_compressed_file = "./decrypted_model.tar"
        try:
            with open(plaintext_compressed_file, "wb") as fout:
                fout.write(plaintext_bytes)
            del plaintext_bytes

            t0 = time()
            _decompress_model_dir(plaintext_compressed_file, output_path)
            print(f"Decompress done, costs {time() - t0} seconds.")
        finally:
            # Don't leave a decrypted copy of the model lying around in the cwd.
            if os.path.exists(plaintext_compressed_file):
                os.remove(plaintext_compressed_file)

    def compute_file_hash(self, file_path: str) -> bytes:
        """Return the raw SHA-256 digest of the file at ``file_path`` (streamed)."""
        digest = hashlib.sha256()
        with open(file_path, "rb") as fin:
            for block in iter(lambda: fin.read(self._HASH_BLOCK_SIZE), b""):
                digest.update(block)
        return digest.digest()

    def compute_model_baseline(self, model_dir: str) -> str:
        """Return a hex fingerprint of every file under ``model_dir``.

        Files are hashed individually (in a thread pool), then folded in sorted
        path order as ``acc = sha256(acc + file_hash)`` starting from ``b""``.
        For an empty directory the result is ``""``.
        """
        file_path_list: List[str] = []

        for dirpath, _, filenames in os.walk(model_dir):
            for filename in filenames:
                file_path_list.append(os.path.join(dirpath, filename))

        # Sorting makes the fingerprint independent of os.walk's traversal order.
        file_path_list.sort()

        num_threads = os.cpu_count()
        print(f"        Using {num_threads} threads")

        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
            file_hashes = list(executor.map(self.compute_file_hash, file_path_list))
        baseline = reduce(lambda acc, x: hashlib.sha256(acc + x).digest(), file_hashes, b"")

        return baseline.hex()

    def import_key(self, ring_id: str, key: bytes, key_name: str, desc="") -> None:
        """Import ``key`` as a SYMMETRIC_256 key into ring ``ring_id`` via TKS.

        Logs in to PCC first; the session cookie is forwarded to TKS together
        with ``usage_scenario=ModelEncryption``. Requires an ``EncryptionConfig``.
        """
        cookies = {
            "jeddak_pcc_session": self._fetch_pcc_session(),
            "usage_scenario": "ModelEncryption",
        }

        self.tks_client.import_key(
            ring_id=ring_id,
            algo="SYMMETRIC_256",
            key=key,
            key_name=key_name,
            desc=desc,
            cookies=cookies,
        )

    def get_key(self, ring_id: str, key_id: str) -> bytes:
        """Fetch key ``key_id`` from ring ``ring_id`` via TKS after a PCC login."""
        cookies = self._pcc_login()
        key = self.tks_client.get_key(ring_id, key_id, cookies=cookies)
        return key

    def upload_model_to_tos(
        self,
        model_path: str,
        bucket_name: str,
        volc_ak: str,
        volc_sk: str,
        region: str,
        endpoint: str,
        processes: int = 10,
    ):
        """Upload every regular file under ``model_path`` to ``bucket_name``.

        Files are uploaded in parallel by ``processes`` worker processes, each
        running ``upload_files``. The object key of a file is its local path as
        produced by ``Path(model_path).rglob``.

        TOS client/server errors and any other exception are printed and
        swallowed (not re-raised), so callers are not informed of failures.
        """
        try:
            files = [str(file) for file in Path(model_path).rglob("*") if file.is_file()]
            param_list = [
                {
                    "bucket_name": bucket_name,
                    "path": file,
                    "volc_ak": volc_ak,
                    "volc_sk": volc_sk,
                    "region": region,
                    "endpoint": endpoint,
                }
                for file in files
            ]
            # The context manager tears the pool down even when a worker raises,
            # instead of leaking worker processes.
            with Pool(processes) as pool:
                pool.map(upload_files, param_list)
                pool.close()
                pool.join()
        except tos.exceptions.TosClientError as e:
            # Client-side failure: usually invalid request parameters or a network error.
            print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
        except tos.exceptions.TosServerError as e:
            # Server-side failure: details are available on the exception.
            print('fail with server error, code: {}'.format(e.code))
            # The request id pinpoints the failing request; worth keeping in logs.
            print('error with request id: {}'.format(e.request_id))
            print('error with message: {}'.format(e.message))
            print('error with http code: {}'.format(e.status_code))
            print('error with ec: {}'.format(e.ec))
            print('error with request url: {}'.format(e.request_url))
        except Exception as e:
            print('fail with unknown error: {}'.format(e))

    @staticmethod
    def _tos_url(bucket_name: str, region: str, prefix: str) -> str:
        """Build the (VPC/internal) TOS URL of folder ``prefix`` in ``bucket_name``."""
        if region == "e28-env":
            return f"https://{bucket_name}.tos-vpc-{region}.e28.inspirecloud.io/{prefix}/"
        if region == "neimeng-1":
            return f"https://{bucket_name}.tos-vpc-neimeng-1.pcc.lenovo.com/{prefix}/"
        return f"https://{bucket_name}.tos-{region}.ivolces.com/{prefix}/"

    def encrypt_model_and_upload(
        self,
        model_path: str,
        bucket_name: str,
        volc_ak: str,
        volc_sk: str,
        region: str,
        endpoint: str,
        app_id: str,
        top_ak: str = "",
        top_sk: str = "",
        processes: int = 10,
        ring_id: str = "",
        ring_name: str = "",
        ring_desc: str = "",
        key_name: str = "",
        key_desc: str = "",
        service: str = "pcc",
        tks_addr: str = "open.volcengineapi.com",
        encrypt_flag: bool = True,
    ) -> dict:
        """Publish a model to TOS, optionally encrypting it first.

        With ``encrypt_flag`` (default):

        1. build a TKS client from ``volc_ak``/``volc_sk``;
        2. create a key ring named ``ring_name`` if ``ring_id`` is empty;
        3. encrypt ``model_path`` into ``enc-<model name>``;
        4. import the data key into the ring as ``key_name``;
        5. compute the plaintext baseline;
        6. upload the encrypted folder.

        Returns ``{"ring_id", "key_id", "baseline", "model_name", "tos_url"}``.

        Without ``encrypt_flag``, the model is copied to ``plain-<model name>``
        and uploaded as-is. Returns ``{"baseline", "model_name", "tos_url"}``.

        ``top_ak``/``top_sk`` are unused; they are kept for backward compatibility.

        Raises:
            AssertionError: if encryption is requested without ``key_name``, or
                a ring must be created without ``ring_name``.
        """
        model_name = os.path.basename(model_path)

        if not encrypt_flag:
            print("    Step1.1: model doesn't need encryption")
            destination = "plain-" + model_name
            copy_folder(model_path, destination)
            print("    Step1.2: start compute model baseline")
            baseline = self.compute_model_baseline(model_path)
            print("    Step1.3: compute model baseline done.")
            print("    Step1.4: start upload model to tos...")
            t0 = time()
            self.upload_model_to_tos(model_path=destination + "/", bucket_name=bucket_name, volc_ak=volc_ak, volc_sk=volc_sk, region=region, endpoint=endpoint, processes=processes)
            print(f"    Step1.5: upload to tos done, costs: {time() - t0} seconds.")
            return {
                "baseline": baseline,
                "model_name": model_name,
                "tos_url": self._tos_url(bucket_name, region, destination),
            }

        print("    Step1.1: model needs encryption")
        # Validate up front so a missing key name fails before a ring is created
        # or the model is encrypted.
        assert key_name, "key_name should provided"

        top_conf = {
            "ak": volc_ak,
            "sk": volc_sk,
            "service": service,
            "region": region,
        }
        tks_config = TKSConfig(addr=tks_addr, enable_tls=True, top_config=top_conf)
        tks_client = TKSClient(app_id=app_id, config=tks_config)

        # Step numbering shifts by two when a ring has to be created first.
        step = 1
        if not ring_id:
            print("    Step1.2: ring_id is not provided, start creating.")
            assert ring_name, "ring_name should provided when create"
            ring_id = tks_client.create_ring(ring_name=ring_name, desc=ring_desc)["RingID"]
            print("    Step1.3: create ring successfully.")
            step = 3

        enc_output_path = "enc-" + model_name
        data_key = self.encrypt_model(model_path, output_path=enc_output_path)

        print("    Step1.{}: start import key to tks...".format(1 + step))
        key_id = tks_client.import_key(ring_id=ring_id, algo="SYMMETRIC_256", key=data_key, key_name=key_name, desc=key_desc, usage_scenario="ModelEncryption")["KeyID"]
        print("    Step1.{}: import key to tks successfully.".format(2 + step))

        print("    Step1.{}: start compute model baseline".format(3 + step))
        baseline = self.compute_model_baseline(model_path)
        print("    Step1.{}: compute model baseline done.".format(4 + step))

        print("    Step1.{}: start upload enc model to tos...".format(5 + step))
        t0 = time()
        self.upload_model_to_tos(model_path=enc_output_path + "/", bucket_name=bucket_name, volc_ak=volc_ak, volc_sk=volc_sk, region=region, endpoint=endpoint, processes=processes)
        print("    Step1.{}: upload to tos done, costs: {} seconds.".format(6 + step, time() - t0))

        return {
            "ring_id": ring_id,
            "key_id": key_id,
            "baseline": baseline,
            "model_name": model_name,
            "tos_url": self._tos_url(bucket_name, region, enc_output_path),
        }

    def _fetch_pcc_session(self) -> Optional[str]:
        """Log in to PCC and return the ``jeddak_pcc_session`` cookie (``None`` if absent).

        The password is sent as ``md5(password + timestamp)``, with the same
        timestamp passed in the ``TimeStamp`` header.
        """
        ts = str(utc8_timestamp())
        req_body = {"Username": self.app_id, "Password": md5_hash(self.password.encode(), ts)}

        response = requests.post(self.login_url, headers={"TimeStamp": ts}, json=req_body)
        return response.cookies.get("jeddak_pcc_session")

    def _pcc_login(self) -> Dict[str, str]:
        """Log in to PCC and return the session cookie dict; asserts the login succeeded."""
        jeddak_pcc_session = self._fetch_pcc_session()
        assert jeddak_pcc_session is not None
        return {"jeddak_pcc_session": jeddak_pcc_session}
