"""
Copyright (year) Beijing Volcano Engine Technology Ltd.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import hashlib
import hmac
from urllib.parse import quote
import json
import requests
from datetime import datetime, timezone


# Endpoint-level settings. They differ between Volcano Engine services but are
# normally the same for every action of one service. The service name itself
# is supplied per call through the ``service`` argument of ``request``.
Host = "open.volcengineapi.com"
Region = "cn-beijing"
Version = "2024-12-24"  # sent as the ``Version`` query parameter of every call
ContentType = "application/json"


def norm_query(params):
    """Return the canonical query string that is fed into the request signature.

    Keys are emitted in ascending sorted order. A list or tuple value expands
    into repeated ``key=value`` pairs in its original order. Keys and values are
    percent-encoded with only the unreserved characters ``-_.~`` left literal,
    so a space becomes ``%20`` and ``+`` becomes ``%2B``.

    To keep the signed string identical to what ``requests`` puts on the wire:
    ``None`` values (and ``None`` list items) are skipped, and non-string
    scalars such as ints or bools are converted with ``str()``.

    Args:
        params: Mapping of query parameter names to values.

    Returns:
        The ``&``-joined canonical query, or ``""`` when nothing remains.
    """
    def _encode(text):
        # safe="-_.~" overrides quote's default of "/", so "/" is escaped too.
        return quote(text, safe="-_.~")

    pairs = []
    for key in sorted(params):
        value = params[key]
        items = value if isinstance(value, (list, tuple)) else [value]
        for item in items:
            if item is None:
                continue
            if not isinstance(item, (str, bytes)):
                item = str(item)
            pairs.append(_encode(key) + "=" + _encode(item))
    return "&".join(pairs)


# Step 1: low-level helpers used by the signing code below.
def hmac_sha256(key: bytes, content: str) -> bytes:
    """Return the raw HMAC-SHA256 digest of ``content`` (UTF-8 encoded) under ``key``.

    HMAC is a symmetric keyed hash, not encryption: the same secret key both
    produces and verifies the MAC.
    """
    mac = hmac.new(key, digestmod=hashlib.sha256)
    mac.update(content.encode("utf-8"))
    return mac.digest()


# Hex-encoded SHA-256 digest helper.
def hash_sha256(content: str) -> str:
    """Return the lowercase hex SHA-256 digest of ``content`` encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(content.encode("utf-8"))
    return hasher.hexdigest()


def indent_text(text, spaces=4):
    """Prefix every line of ``text`` with ``spaces`` blank characters.

    Lines are split with ``str.splitlines`` and re-joined with newline
    characters. Trailing line breaks are therefore dropped and other line
    endings are normalized. Empty input yields an empty string.
    """
    pad = " " * spaces
    padded_lines = [pad + line for line in text.splitlines()]
    return "\n".join(padded_lines)


# Step 2: sign and send a request.
def request(method, date, query, header, ak, sk, action, body, silent=False, service="pcc", timeout=None):
    """Sign a Volcano Engine OpenAPI call with HMAC-SHA256 and send it.

    The call always targets ``https://{Host}/`` with ``Action`` and ``Version``
    merged into the query string. It is signed over the ``content-type``,
    ``host``, ``x-content-sha256`` and ``x-date`` headers.

    Args:
        method: HTTP method, e.g. ``"POST"``.
        date: Signing time. Naive values are treated as UTC. Timezone-aware
            values are converted to UTC first.
        query: Extra query parameters, merged after ``Action``/``Version``.
        header: Extra HTTP headers. The signing headers override keys with the
            same spelling.
        ak: AccessKeyID.
        sk: SecretAccessKey.
        action: OpenAPI action name.
        body: Request body as ``str`` (sent UTF-8 encoded), ``bytes``, or
            ``None`` for an empty body.
        silent: When False, print the intermediate signing values for debugging.
        service: Service name used in the credential scope.
        timeout: Passed through to ``requests.request``. ``None`` (the default)
            waits indefinitely.

    Returns:
        The response body decoded as JSON.

    Raises:
        requests.exceptions.RequestException: on transport errors.
        ValueError: (a JSON decode error subclass) if the response is not JSON.
    """
    # Step 3: credential. ak/sk are the AccessKeyID/SecretAccessKey pair;
    # service and region become part of the credential scope.
    credential = {
        "access_key_id": ak,
        "secret_access_key": sk,
        "service": service,
        "region": Region,
    }
    if body is None:
        body = ""
    # Hash and send the very same bytes. A str handed to requests leaves its
    # wire encoding to the HTTP stack (older urllib3 / http.client use
    # Latin-1), which may fail or invalidate the signature for non-ASCII text.
    body_bytes = body if isinstance(body, bytes) else body.encode("utf-8")
    if date.tzinfo is not None:
        # X-Date must be UTC. An aware datetime in another zone would otherwise
        # be formatted in its local time with a misleading "Z" suffix.
        date = date.astimezone(timezone.utc)
    request_param = {
        "host": Host,
        "path": "/",
        "method": method,
        "content_type": ContentType,
        "query": {"Action": action, "Version": Version, **query},
    }

    # Step 4: values shared by the canonical request and the outgoing headers.
    x_date = date.strftime("%Y%m%dT%H%M%SZ")
    short_x_date = x_date[:8]
    x_content_sha256 = hashlib.sha256(body_bytes).hexdigest()
    sign_result = {
        "Host": request_param["host"],
        "X-Content-Sha256": x_content_sha256,
        "X-Date": x_date,
        "Content-Type": request_param["content_type"],
    }

    # Step 5: compute the signature.
    signed_headers_str = ";".join(["content-type", "host", "x-content-sha256", "x-date"])
    canonical_headers = "\n".join([
        "content-type:" + request_param["content_type"],
        "host:" + request_param["host"],
        "x-content-sha256:" + x_content_sha256,
        "x-date:" + x_date,
    ])
    canonical_request_str = "\n".join([
        request_param["method"].upper(),
        request_param["path"],
        norm_query(request_param["query"]),
        canonical_headers,
        "",  # blank line terminating the canonical headers
        signed_headers_str,
        x_content_sha256,
    ])
    if not silent:
        # Canonical request, printed for comparison when debugging.
        print(indent_text(canonical_request_str))
    hashed_canonical_request = hash_sha256(canonical_request_str)
    if not silent:
        print(indent_text(hashed_canonical_request))
    credential_scope = "/".join([short_x_date, credential["region"], credential["service"], "request"])
    string_to_sign = "\n".join(["HMAC-SHA256", x_date, credential_scope, hashed_canonical_request])
    if not silent:
        print(indent_text(string_to_sign))
    # Derive the signing key: secret -> date -> region -> service -> "request".
    k_date = hmac_sha256(credential["secret_access_key"].encode("utf-8"), short_x_date)
    k_region = hmac_sha256(k_date, credential["region"])
    k_service = hmac_sha256(k_region, credential["service"])
    k_signing = hmac_sha256(k_service, "request")
    signature = hmac_sha256(k_signing, string_to_sign).hex()

    sign_result["Authorization"] = "HMAC-SHA256 Credential={}, SignedHeaders={}, Signature={}".format(
        credential["access_key_id"] + "/" + credential_scope,
        signed_headers_str,
        signature,
    )
    if not silent:
        print(indent_text(json.dumps(sign_result, indent=2)))
    header = {**header, **sign_result}
    # Step 6: send the request with the signature headers attached.
    r = requests.request(
        method=method,
        url="https://{}{}".format(request_param["host"], request_param["path"]),
        headers=header,
        params=request_param["query"],
        data=body_bytes,
        timeout=timeout,
    )
    return r.json()


def set_tks_policy(AK: str, SK: str, uid: str, key_id: str, policy_id: str, silent: bool = False, service: str = "pcc"):
    """Call the ``SetTksPolicy`` action with a single ``RAPolicy`` rule for ``key_id``.

    Args:
        AK: AccessKeyID used to sign the request.
        SK: SecretAccessKey used to sign the request.
        uid: Sent as the ``AppID`` HTTP header.
        key_id: Sent as the ``ID`` field of the body (``Range`` is ``"key"``).
        policy_id: Policy referenced by the rule's ``policyID`` field.
        silent: When False, print signing debug output and the response.
        service: Service name used for signing.

    Returns:
        The decoded JSON response. It is printed as "succeed" whenever
        ``silent`` is False, regardless of its content.
    """
    # One aware UTC instant for both the signature date and the Timestamp
    # header (datetime.utcnow() is deprecated since Python 3.12).
    now = datetime.now(timezone.utc)
    header = {
        "AppID": uid,
        "Token": "",
        "Signature": "",
        "Timestamp": str(int(now.timestamp())),
    }
    # Serialize with json.dumps so a policy_id containing quotes or backslashes
    # cannot corrupt the JSON. Compact separators keep the exact layout
    # [[{"tee":"none","platform":"*",...}]].
    rules = [[{
        "tee": "none",
        "platform": "*",
        "policyID": policy_id,
        "type": "RAPolicy",
        "optional": False,
        "inverse": False,
    }]]
    data = {
        "ID": key_id,
        "Range": "key",
        "Rules": json.dumps(rules, separators=(",", ":"), ensure_ascii=False),
        "IssueToken": False,
        "TokenLifetime": 360000,
        "AppID": "jsc_pcc",
        "Signature": "",
        "Token": "",
    }
    response_body = request(
        "POST", now, {}, header, AK, SK, "SetTksPolicy", json.dumps(data), silent, service
    )
    if not silent:
        print("    set_tks_policy succeed with resp =", response_body)
    return response_body


def publish_inference_model(
        AK: str, SK: str,
        model_name: str, model_version: str, model_description: str, model_baseline: str,
        tos_url: str,
        ring_id: str, key_id: str,
        encrypt_flag: bool, silent: bool = False,
        service: str = "pcc"):
    """Call the ``PublishInferenceModel`` action with the given model metadata.

    Each argument maps one-to-one onto a body field: ``model_name`` ->
    ``ModelName``, ``model_version`` -> ``ModelVersion``, ``model_description``
    -> ``ModelDescription``, ``model_baseline`` -> ``ModelBaseline``,
    ``tos_url`` -> ``TosUrl``, ``ring_id`` -> ``RingID``, ``key_id`` ->
    ``KeyID``, ``encrypt_flag`` -> ``EncryptFlag``.

    Args:
        AK: AccessKeyID used to sign the request.
        SK: SecretAccessKey used to sign the request.
        silent: When False, print signing debug output and the response.
        service: Service name used for signing.

    Returns:
        The decoded JSON response. It is printed as "succeed" whenever
        ``silent`` is False, regardless of its content.
    """
    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    data = {
        "ModelName": model_name,
        "ModelVersion": model_version,
        "ModelDescription": model_description,
        "ModelBaseline": model_baseline,
        "TosUrl": tos_url,
        "RingID": ring_id,
        "KeyID": key_id,
        "EncryptFlag": encrypt_flag,
    }
    response_body = request(
        "POST", now, {}, {}, AK, SK, "PublishInferenceModel", json.dumps(data), silent, service
    )
    if not silent:
        print("    publish_inference_model succeed with resp =", response_body)
    return response_body


class _InvalidArgumentError(ValueError, AssertionError):
    """Raised when an argument fails validation.

    Also subclasses AssertionError so callers catching AssertionError keep working.
    """


def _require(condition, message):
    """Raise _InvalidArgumentError(message) unless ``condition`` is truthy.

    Unlike ``assert``, this check is not stripped when Python runs with ``-O``.
    """
    if not condition:
        raise _InvalidArgumentError(message)


def deploy_inference_model(AK: str, SK: str, model_id: str, cu_type: str, replica_count: int, inference_engine: str, engine_conf: dict, silent: bool = False, service: str = "pcc"):
    """Validate the deployment options and call the ``DeployInferenceModel`` action.

    Accepted combinations: ``cu_type`` is one of Basic/Advanced/Enterprise and
    ``inference_engine`` is one of vllm/sglang. Basic requires vllm and
    Enterprise requires sglang. EnableEic, EnableHpa and EnableAntiDirt are
    always sent as False.

    Args:
        AK: AccessKeyID used to sign the request.
        SK: SecretAccessKey used to sign the request.
        model_id: Sent as ``ModelID``.
        cu_type: Sent as ``CuType``.
        replica_count: Positive int (bools are rejected), sent as ``ReplicaCount``.
        inference_engine: Sent as ``InferenceEngine``.
        engine_conf: Sent as ``EngineConf``; must be JSON-serializable.
        silent: When False, print signing debug output and the response.
        service: Service name used for signing.

    Returns:
        The decoded JSON response.

    Raises:
        ValueError: (also an AssertionError) for an invalid argument, before any
            request is sent.
    """
    _require(cu_type in ["Basic", "Advanced", "Enterprise"], "cu_type must be one of Basic, Advanced, Enterprise")
    _require(inference_engine in ["vllm", "sglang"], "inference_engine must be one of vllm, sglang")
    # bool is a subclass of int, so True would otherwise pass as 1 and be
    # serialized as JSON `true`.
    _require(
        isinstance(replica_count, int) and not isinstance(replica_count, bool) and replica_count > 0,
        "replica_count must be a positive integer",
    )

    if cu_type == "Basic":
        _require(inference_engine == "vllm", "Basic cu_type only supports vllm inference_engine")
    elif cu_type == "Enterprise":
        _require(inference_engine == "sglang", "Enterprise cu_type only supports sglang inference_engine")

    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    data = {
        "ModelID": model_id,
        "CuType": cu_type,
        "ReplicaCount": replica_count,
        "InferenceEngine": inference_engine,
        "EngineConf": engine_conf,
        "EnableEic": False,
        "EnableHpa": False,
        "EnableAntiDirt": False,
    }
    response_body = request(
        "POST", now, {}, {}, AK, SK, "DeployInferenceModel", json.dumps(data), silent, service
    )
    if not silent:
        print("    deploy_inference_model succeed with resp =", response_body)
    return response_body


def get_engine_conf(AK: str, SK: str, cu_type: str, inference_engine: str, model_id: str, silent: bool = False, service: str = "pcc"):
    """Call the ``GetModelEngineConf`` action and return the decoded JSON response.

    Unlike the other wrappers, nothing is printed here; ``silent`` is only
    forwarded to ``request``. The arguments are not validated locally.

    Args:
        AK: AccessKeyID used to sign the request.
        SK: SecretAccessKey used to sign the request.
        cu_type: Sent as ``CuType``.
        inference_engine: Sent as ``InferenceEngine``.
        model_id: Sent as ``ModelId``.
        silent: Forwarded to ``request``.
        service: Service name used for signing.
    """
    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    data = {
        "CuType": cu_type,
        "InferenceEngine": inference_engine,
        # NOTE(review): other actions in this file send "ModelID"; this one uses
        # "ModelId". Confirm against the GetModelEngineConf API reference.
        "ModelId": model_id,
    }
    return request(
        "POST", now, {}, {}, AK, SK, "GetModelEngineConf", json.dumps(data), silent, service
    )


def list_inference_model(AK: str, SK: str, silent: bool = False, service: str = "pcc",
                         page_number: int = 1, page_size: int = 1000):
    """Call the ``ListInferenceModel`` action for one page and return the decoded JSON response.

    Only the requested page is fetched; callers needing more results must call
    again with another ``page_number``.

    Args:
        AK: AccessKeyID used to sign the request.
        SK: SecretAccessKey used to sign the request.
        silent: Forwarded to ``request``. Nothing is printed here.
        service: Service name used for signing.
        page_number: Sent as ``PageNumber`` (default 1).
        page_size: Sent as ``PageSize`` (default 1000).
    """
    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    data = {
        "PageNumber": page_number,
        "PageSize": page_size,
    }
    return request(
        "POST", now, {}, {}, AK, SK, "ListInferenceModel", json.dumps(data), silent, service
    )


def test_inference_model(AK: str, SK: str, model_id: str, silent: bool = False, service: str = "pcc"):
    """Call the ``TestInferenceModel`` action for ``model_id`` (sent as ``ModelID``).

    Args:
        AK: AccessKeyID used to sign the request.
        SK: SecretAccessKey used to sign the request.
        model_id: Model identifier.
        silent: When False, print signing debug output and the response.
        service: Service name used for signing.

    Returns:
        The decoded JSON response. It is printed as "succeed" whenever
        ``silent`` is False, regardless of its content.
    """
    # Aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    data = {
        "ModelID": model_id,
    }
    response_body = request(
        "POST", now, {}, {}, AK, SK, "TestInferenceModel", json.dumps(data), silent, service
    )
    if not silent:
        print("    test_inference_model succeed with resp =", response_body)
    return response_body


# The "test_" prefix would make pytest collect this API call as a test (and
# fail on the missing AK/SK fixtures) if it is imported into a test module.
test_inference_model.__test__ = False
