from bytedance.jeddak_secure_model.model_encryption import EncryptionConfig, JeddakModelEncrypter
from bytedance.jeddak_secure_model.rds_utils import register_to_rds
from huggingface_hub import snapshot_download
import argparse
import os
import shutil
from pathlib import Path

parser = argparse.ArgumentParser(description="A useful tool for administrators to host a new model.")
parser.add_argument("--huggingface_model", type=str, default="Qwen/Qwen2.5-0.5B", help="the huggingface model name")
parser.add_argument("--rds_url", type=str, default="", help="the rds url used for storing model table")
parser.add_argument("--bucket_name", type=str, default="jeddakpcc-platform-hosted-models", help="the bucket_name of tos")
parser.add_argument("--volc_ak", type=str, default="", help="the Volcengine access key (AK) used to access tos")
parser.add_argument("--volc_sk", type=str, default="", help="the Volcengine secret key (SK) used to access tos")
parser.add_argument("--region", type=str, default="cn-beijing", help="the region of tos")
parser.add_argument("--endpoint", type=str, default="tos-cn-beijing.volces.com", help="the endpoint of tos")
args = parser.parse_args()

# Reject missing required values with a usage error (exit status 2). `assert` is
# deliberately not used: it is stripped under `python -O`, which would let empty
# credentials through to the upload/registration steps.
_REQUIRED_ARGS = ("rds_url", "bucket_name", "volc_ak", "volc_sk", "region", "endpoint")
for _arg_name in _REQUIRED_ARGS:
    if not getattr(args, _arg_name):
        parser.error(f"{_arg_name} is missing")

model_name = args.huggingface_model
rds_url = args.rds_url
bucket_name = args.bucket_name
volc_ak = args.volc_ak
volc_sk = args.volc_sk
region = args.region
endpoint = args.endpoint

# Key-management / credential parameters forwarded to encrypt_model_and_upload().
# They are left empty because that call passes encrypt_flag=False (plaintext upload);
# presumably they are ignored in that mode -- confirm before enabling encryption.
app_id = ""
top_ak = ""
top_sk = ""
ring_id = ""
ring_name = ""
ring_desc = ""
key_name = ""
key_desc = ""
service = ""


def replace_symlinks_with_copies(model_name, directory):
    """Materialise a downloaded snapshot as real files, then delete the cached repo.

    Every file under ``directory`` is copied into a folder at ``model_name``.
    Symlinks are dereferenced so that real file contents are copied, and
    regular files are copied as-is. The snapshot's sub-directory structure is
    preserved. Afterwards the cached repo folder that owns the snapshot (the
    grandparent of ``directory``) is removed.

    Args:
        model_name: Destination folder path, e.g. ``"Qwen/Qwen2.5-0.5B"``.
            It is created together with any missing parents.
        directory: Snapshot directory laid out as
            ``<cache>/<repo-folder>/snapshots/<revision>``, e.g. the value
            returned by ``snapshot_download``.

    Raises:
        ValueError: If ``directory`` does not have that layout. The check runs
            before anything is created, copied or deleted.
    """
    snapshot_dir = Path(directory).resolve()
    repo_root = snapshot_dir.parent.parent
    # Guard the rmtree below: only ever delete the repo folder owning this
    # snapshot (named like "models--Org--Name"), never a filesystem root, the
    # current directory or an arbitrary ancestor.
    if snapshot_dir.parent.name != "snapshots" or "--" not in repo_root.name:
        raise ValueError(
            f"{directory!r} is not a snapshot directory of a Hugging Face cache repo"
        )

    folder_path = Path(model_name)
    folder_path.mkdir(parents=True, exist_ok=True)
    print(f"making dir for {model_name}")
    for root, _, files in os.walk(snapshot_dir):
        # Keep sub-directories so same-named files in different folders do not collide.
        relative_root = Path(root).relative_to(snapshot_dir)
        for file in files:
            src = Path(root, file).resolve()  # dereference symlinks to the real file
            dest = folder_path / relative_root / file
            dest.parent.mkdir(parents=True, exist_ok=True)
            print(f"copying {src} to {dest}")
            shutil.copy(src, dest)
    print(f"removing {repo_root}")
    shutil.rmtree(repo_root)


# Download the model snapshot into the current directory.
# NOTE: replace_symlinks_with_copies(model_name, local_dir) can be used instead
# of the rename below to materialise real files and delete the download cache.
local_dir = snapshot_download(repo_id=model_name, cache_dir="./")

# Rename the downloaded snapshot directory to the short model name (the part
# after the last "/"), keeping it in the same parent directory.
old = Path(local_dir)
new = old.parent / model_name.split("/")[-1]
# Remove a copy left behind by a previous run. A missing directory is the normal
# first-run case; any other error (e.g. permissions) is real and must surface
# here instead of as a confusing failure in rename() below.
try:
    shutil.rmtree(new)
except FileNotFoundError:
    pass
old.rename(new)

model_path = str(new)  # path of the model folder to upload


# Upload the model folder to TOS. encrypt_flag=False requests a plaintext upload,
# so the empty key-management placeholders defined earlier are presumably ignored
# by the encrypter -- TODO confirm against JeddakModelEncrypter.
config = EncryptionConfig("", "")
encrypter = JeddakModelEncrypter(config)

res = encrypter.encrypt_model_and_upload(
    model_path=model_path,
    bucket_name=bucket_name,
    volc_ak=volc_ak,
    volc_sk=volc_sk,
    region=region,
    endpoint=endpoint,
    top_ak=top_ak,
    top_sk=top_sk,
    ring_id=ring_id,
    ring_name=ring_name,
    ring_desc=ring_desc,
    key_name=key_name,
    key_desc=key_desc,
    app_id=app_id,
    service=service,
    encrypt_flag=False
)

model = model_name.split("/")[-1]
# Register the URL on the "ivolces" variant of the endpoint (presumably the
# internal-network domain). Only the endpoint is rewritten. A blanket replace on
# the whole URL would also alter a bucket name containing "volces", and it would
# turn an endpoint that is already "ivolces" into "iivolces".
if "ivolces" in endpoint:
    internal_endpoint = endpoint
else:
    internal_endpoint = endpoint.replace("volces", "ivolces")
# "plain-<model>/" is presumably the object prefix used for unencrypted uploads -- confirm.
tos_url = f"https://{bucket_name}.{internal_endpoint}/plain-{model}/"
baseline = res["baseline"]
model_meta_data = {
    "model_name": model_name,
    "tos_url": tos_url,
    "baseline": baseline
}

register_to_rds(rds_url, model_meta_data)