from bytedance.jeddak_secure_model.model_encryption import \
    EncryptionConfig, JeddakModelEncrypter
from bytedance.jeddak_secure_model import pcc_utils
import bytedance
from bytedance.jeddak_secure_model.pcc_utils import \
    set_tks_policy, \
    publish_inference_model, \
    deploy_inference_model, \
    list_inference_model, \
    test_inference_model, \
    get_engine_conf
import argparse
import time
import sys


def parser(argv=None):
    """Define the JeddakPCC automation CLI and parse its arguments.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default, and the original behaviour) argparse reads
            ``sys.argv[1:]``; passing a list makes the CLI testable.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Raises:
        SystemExit: raised by argparse when arguments are invalid or the
            required ``--model_path`` is missing.
    """
    # Named arg_parser so the local does not shadow this function's own name.
    arg_parser = argparse.ArgumentParser(description="JeddakPCC Model Automation Tool")
    # Local model and TOS upload target. A default on a required option is
    # never used, so --model_path has none.
    arg_parser.add_argument("--model_path", type=str, required=True, help="本地模型的文件夹路径")
    arg_parser.add_argument("--bucket_name", type=str, default="", help="创建TOS的存储桶")
    arg_parser.add_argument("--AK", type=str, default="", help="your volc access key")
    arg_parser.add_argument("--SK", type=str, default="", help="your volc secret key")
    arg_parser.add_argument("--region", type=str, default="cn-beijing", help="存储桶的区域")
    arg_parser.add_argument("--endpoint", type=str, default="tos-cn-beijing.volces.com", help="tos endpoint")
    # Key ring / key options forwarded to the encryption step.
    arg_parser.add_argument("--ring_id", type=str, default="", help="密钥环ID")
    arg_parser.add_argument("--ring_name", type=str, default="", help="密钥环名称")
    arg_parser.add_argument("--ring_desc", type=str, default="", help="密钥环描述")
    arg_parser.add_argument("--key_name", type=str, default="", help="密钥名称")
    arg_parser.add_argument("--key_desc", type=str, default="", help="密钥描述")
    # Account / service selection.
    arg_parser.add_argument("--app_id", type=str, default="", help="火山页面查看自己账号的app_id")
    arg_parser.add_argument("--service", type=str, default="pcc", help="pcc for online, pcc_test for ppe")
    arg_parser.add_argument("--encrypt_flag", '-e', action='store_true', help="是否加密模型，输入此参数表示加密，不输入此参数表示不加密")
    arg_parser.add_argument("--tks_addr", type=str, default="open.volcengineapi.com", help="tks top 地址")
    arg_parser.add_argument("--policy_id", type=str, default="", help="策略ID")
    # Model publishing (MMS) metadata.
    arg_parser.add_argument("--model_name", type=str, default="", help="模型名称")
    arg_parser.add_argument("--model_version", type=str, default="", help="模型版本")
    arg_parser.add_argument("--model_description", type=str, default="", help="模型描述")
    # Deployment settings.
    arg_parser.add_argument("--cu_type", type=str, default="", help="CU类型，可选值为Basic, Advanced, Enterprise")
    arg_parser.add_argument("--replica_count", type=int, default=1, help="副本数量，默认1")
    arg_parser.add_argument("--inference_engine", type=str, default="vllm", help="推理引擎，可选值为vllm, sglang")
    # Fixed typo: the help text previously began with a stray "h".
    arg_parser.add_argument("--silent", '-s', action='store_true', help="是否打印http request日志")
    return arg_parser.parse_args(argv)


def _find_model(model_list, model_id):
    """Return the first entry of *model_list* whose ``"ModelID"`` equals *model_id*.

    Args:
        model_list: The ``["Result"]["ModelList"]`` entries returned by
            ``list_inference_model``; each entry is indexed by ``"ModelID"``.
        model_id: The model ID to look for.

    Returns:
        The matching entry, or None when no entry matches.
    """
    return next((model for model in model_list if model["ModelID"] == model_id), None)


def _wait_until_running(args, model_id, step_no, poll_interval=10):
    """Poll ``list_inference_model`` until *model_id* reports ``DeployStatus == "running"``.

    Args:
        args: Parsed CLI namespace (``AK``, ``SK``, ``silent`` and ``service`` are read).
        model_id: ID of the model deployed in the previous step.
        step_no: Step number shown in the progress messages.
        poll_interval: Seconds to sleep between polls.

    Raises:
        Exception: if the model is no longer present in the model list.
    """
    # NOTE(review): there is no timeout. A deployment that never reaches
    # "running" keeps this loop polling forever. Consider bounding it.
    while True:
        list_res = list_inference_model(
            AK=args.AK,
            SK=args.SK,
            silent=args.silent,
            service=args.service
        )
        model = _find_model(list_res["Result"]["ModelList"], model_id)
        if model is None:
            print("Step{}.1: Model not found in JeddakPCC".format(step_no))
            raise Exception("Model not found in JeddakPCC, checkout MMS to see if the model is published or deleted")
        print("    Step{}.1: Queried model deploy status in JeddakPCC with status={}".format(step_no, model["DeployStatus"]))
        if model["DeployStatus"] == "running":
            return
        # Sleep only while still waiting. The previous loop also slept one
        # extra interval after "running" had already been observed.
        time.sleep(poll_interval)


def main():
    """Run the pipeline end to end.

    The steps are: encrypt/upload, optionally set the TKS policy, publish,
    deploy, wait until running, and test.
    """
    args = parser()
    config = EncryptionConfig("", "")
    encrypter = JeddakModelEncrypter(config)

    # Override the module-level Host/Region globals of pcc_utils with the CLI
    # values. pcc_utils is the module object imported at the top of the file,
    # so this is the same assignment previously done through sys.modules.
    # Presumably pcc_utils' request helpers read these globals; confirm in pcc_utils.
    pcc_utils.Host = args.tks_addr
    pcc_utils.Region = args.region

    print("Step1: Encrypting model and uploading to TOS...")
    res = encrypter.encrypt_model_and_upload(
        model_path=args.model_path,
        bucket_name=args.bucket_name,
        volc_ak=args.AK,
        volc_sk=args.SK,
        region=args.region,
        endpoint=args.endpoint,
        ring_id=args.ring_id,
        ring_name=args.ring_name,
        ring_desc=args.ring_desc,
        key_name=args.key_name,
        key_desc=args.key_desc,
        app_id=args.app_id,
        service=args.service,
        tks_addr=args.tks_addr,
        encrypt_flag=args.encrypt_flag
    )
    print(f"Step1: Finished with result={res}")

    # `step` is the number of the last completed step. Later step numbers are
    # offsets from it, so every later step shifts by one when the optional
    # TKS-policy step runs.
    step = 1
    if args.encrypt_flag:
        # Step2: attach the key returned by Step1 to the policy given by
        # --policy_id. The original note says the policy is created
        # beforehand in the JeddakPCC console.
        print("Step2: Setting TKS policy...")
        set_tks_policy(
            AK=args.AK,
            SK=args.SK,
            uid=args.app_id,
            key_id=res["key_id"],
            policy_id=args.policy_id,
            silent=args.silent,
            service=args.service
        )
        print("Step2: Finished setting TKS policy")
        step = 2

    publish_step = step + 1
    deploy_step = step + 2
    query_step = step + 3
    test_step = step + 4

    print(f"Step{publish_step}: Publishing model to JeddakPCC MMS")
    # NOTE(review): a missing ring/key ID is sent as the literal string "None"
    # (unchanged behaviour). Presumably this is what the MMS API expects for
    # unencrypted models; confirm.
    ring_id = res.get("ring_id")
    ring_id = "None" if ring_id is None else ring_id
    key_id = res.get("key_id")
    key_id = "None" if key_id is None else key_id
    publish_res = publish_inference_model(
        AK=args.AK,
        SK=args.SK,
        model_name=args.model_name,
        model_version=args.model_version,
        model_description=args.model_description,
        model_baseline=res["baseline"],
        tos_url=res["tos_url"],
        ring_id=ring_id,
        key_id=key_id,
        encrypt_flag=args.encrypt_flag,
        silent=args.silent,
        service=args.service
    )
    print(f"Step{publish_step}: Published model to JeddakPCC MMS with response = {publish_res}")
    model_id = publish_res["Result"]["ModelID"]

    engine_conf_res = get_engine_conf(
        AK=args.AK,
        SK=args.SK,
        cu_type=args.cu_type,
        inference_engine=args.inference_engine,
        model_id=model_id,
        silent=args.silent,
        service=args.service
    )
    engine_conf = engine_conf_res["Result"]

    print(f"Step{deploy_step}: Deploying model to JeddakPCC")
    deploy_inference_model(
        AK=args.AK,
        SK=args.SK,
        model_id=model_id,
        cu_type=args.cu_type,
        replica_count=args.replica_count,
        inference_engine=args.inference_engine,
        engine_conf=engine_conf,
        silent=args.silent,
        service=args.service
    )
    print(f"Step{deploy_step}: Deployed model to JeddakPCC")

    print(f"Step{query_step}: Querying model deploy status in JeddakPCC")
    _wait_until_running(args, model_id, query_step)
    print(f"Step{query_step}: model deployment succeed")

    print(f"Step{test_step}: Testing model in JeddakPCC")
    test_res = test_inference_model(
        AK=args.AK,
        SK=args.SK,
        model_id=model_id,
        silent=args.silent,
        service=args.service
    )
    print(f"Step{test_step}: model test succeed with resp = {test_res}")
    print(f"Step{5 + step}: Finished")


if __name__ == "__main__":
    main()
