"""
密钥管理服务 (Trusted Key Service, TKS) 的客户端.
TCA采用背调模式
"""

# Public API of this module: the TKS client and its configuration.
__all__ = ["TksClient", "TksConfig"]

import base64
import datetime
import hashlib
import json
import secrets
from dataclasses import dataclass, field

from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from typing_extensions import Any, Optional, Tuple

from .. import error
from ..crypto import AesKey, PrivateKey
from ..log import logger
from .eps import EpsClient
from .ras import RasClient
from .tks_models import (
    AppOptions,
    AttestationResult,
    CertEvidence,
    ChallengeResponse,
    ClientChallenge,
    ExportationResult,
    KeyExportation,
    KeyImportation,
    KeyIndexWithVersion,
    RAEvidence,
    SecurityAttestation,
    TksAuthHeaders,
    TksResponse,
    ZTIEvidence,
)


@dataclass
class TksConfig:
    """Settings required to reach the Trusted Key Service (TKS)."""

    # HTTP base URL of the TKS service; request endpoint paths are appended to it.
    tks_url: str = field(default="http://localhost:6789")

    # Time zone of the TKS service, as an offset in hours. Rarely needs changing.
    tks_timezone: float = field(default=8.0)

    # Whether the attestation report provided by TKS is verified. Enabled by default.
    tks_enable_server_auth: bool = field(default=True)

    # Whether this client's own attestation report is sent to TKS (mutual
    # attestation). Enabled by default.
    tks_enable_bi_auth: bool = field(default=True)


def sign_message(message: bytes, key: str) -> str:
    """
    Sign a message with the given private key and Base64-encode the signature.

    Args:
        message: The bytes to sign; passed to `PrivateKey.sign` unchanged.
        key: The private key in PEM format.
    Returns:
        The signature, standard-Base64 encoded, as a `str`.
    Raises:
        error.EncryptionError: If the key cannot be loaded or signing fails.
    """

    try:
        raw_signature = PrivateKey.from_private_pem(key).sign(message)
        encoded_signature = base64.standard_b64encode(raw_signature)
        return encoded_signature.decode()
    except Exception as e:
        logger.critical("Message signing failed")
        raise error.EncryptionError("Message signing failed") from e


def app_headers(message: bytes, app: AppOptions, config: TksConfig) -> TksAuthHeaders:
    """
    Build the app-specific TKS authentication headers for a request body.

    `AppID` and `Timestamp` are always set. `Signature` is set when the app has a
    signing key: it signs SHA-256(message || app id || timestamp). `Token` is set
    when the app has a ZTI token.

    Args:
        message: The serialized request body.
        app: App-specific authentication options.
        config: TKS access configuration (only `tks_timezone` is read).
    Returns:
        The populated authentication headers.
    """

    offset = datetime.timedelta(hours=config.tks_timezone)
    now = datetime.datetime.now(datetime.timezone(offset))
    # NOTE(review): `.timestamp()` of an aware datetime is an absolute epoch value and does
    # not depend on its tzinfo, so `tks_timezone` does not change the timestamp sent here.
    stamp = str(int(now.timestamp()))
    headers = TksAuthHeaders(AppID=app.id, Timestamp=stamp)

    if app.signing_key:
        digest = hashlib.sha256(message + app.id.encode() + stamp.encode()).digest()
        headers["Signature"] = sign_message(digest, app.signing_key)

    if app.zti_token:
        headers["Token"] = app.zti_token

    return headers


def tks_request(endpoint: str, body: Any, app: Optional[AppOptions], config: TksConfig) -> Any:
    """
    Send an HTTP POST request to the given TKS endpoint and return its result.

    Args:
        endpoint: Path appended to `config.tks_url`, e.g. "/api/tks/v1/key/export".
        body: Request body object; must be JSON-serializable.
        app: App-specific authentication options used to set the `AppID`,
            `Timestamp`, `Signature` and `Token` headers. `None` sends none of them.
        config: TKS access configuration.
    Returns:
        The `Result` field of the TKS response body.
    Raises:
        error.NetworkError: If sending the request fails.
        error.ServiceError: If the response body is not JSON, is not a JSON object,
            or has no `Result` field. `ResponseMetadata.Error` is used as the
            message when present; otherwise the whole body is.
    """
    import requests

    body_json = json.dumps(body).encode()
    headers = {"Content-Type": "application/json"}
    if app:
        tks_headers = app_headers(body_json, app, config)
        headers.update(**tks_headers)

    try:
        # NOTE(review): no `timeout` is passed, so a TKS server that stops responding
        # blocks this call indefinitely; consider adding a configurable timeout.
        response = requests.post(config.tks_url + endpoint, headers=headers, data=body_json)
    except Exception as e:
        logger.critical(f"Network error: service={config.tks_url} {endpoint=}")
        raise error.NetworkError("TKS", config.tks_url, endpoint) from e

    try:
        response_json: TksResponse[Any] = response.json()
    except Exception as e:
        logger.critical(f"Response is not JSON: service={config.tks_url} {endpoint=}")
        raise error.ServiceError("TKS", config.tks_url, endpoint, "not JSON") from e

    # Valid JSON is not necessarily an object: arrays, strings, numbers and null would make
    # the lookups raise TypeError rather than KeyError, escaping as an unwrapped error.
    # Check the shape explicitly so every malformed response becomes a ServiceError.
    if isinstance(response_json, dict) and "Result" in response_json:
        return response_json["Result"]

    metadata = response_json.get("ResponseMetadata") if isinstance(response_json, dict) else None
    if isinstance(metadata, dict) and "Error" in metadata:
        message = str(metadata["Error"])
    else:
        message = str(response_json)
    logger.critical(f"Error response: service={config.tks_url} {endpoint=} {message=}")
    raise error.ServiceError("TKS", config.tks_url, endpoint, message)


# Length in bytes (256 bits) of the symmetric key derived by `key_exchange`.
AES_KEY_LEN = 256 // 8

# Length in bytes (96 bits) of the nonce that prefixes encrypted TKS payloads.
AES_NONCE_LEN = 96 // 8

# Length in bytes (128 bits) of the MAC that suffixes encrypted TKS payloads.
AES_MAC_LEN = 128 // 8

# Key lengths in bytes (128 or 256 bits) accepted for import and on export.
ALLOWED_KEY_LEN = [128 // 8, 256 // 8]


def key_exchange(server_pub_key: str) -> Tuple[bytes, bytes]:
    """
    Negotiate a symmetric key with TKS via ECDH on curve SECP384R1.

    Each call generates a fresh client key pair and runs ECDH against the
    server's public key. The shared secret is expanded with HKDF-SHA256
    (no salt, no info) into an `AES_KEY_LEN`-byte key.

    Args:
        server_pub_key: The server's public key in PEM format.
    Returns:
        The client's public key (PEM, SubjectPublicKeyInfo) and the derived
        symmetric key.
    Raises:
        error.EncryptionError: If the server key cannot be loaded, is not an
            elliptic-curve key, or the exchange/derivation fails.
    """

    try:
        peer_key = serialization.load_pem_public_key(server_pub_key.encode())
        if not isinstance(peer_key, ec.EllipticCurvePublicKey):
            # Only EC keys can take part in ECDH; reported by the handler below.
            raise RuntimeError()

        ephemeral_key = ec.generate_private_key(ec.SECP384R1())
        secret = ephemeral_key.exchange(ec.ECDH(), peer_key)

        own_public_pem = ephemeral_key.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )

        kdf = HKDF(algorithm=hashes.SHA256(), length=AES_KEY_LEN, salt=None, info=None)
        derived_key = kdf.derive(secret)
    except Exception as e:
        logger.critical("Key exchange failed")
        raise error.EncryptionError("Key exchange failed") from e

    return own_public_pem, derived_key


def encrypt_key(symmetric_key: bytes, plaintext: bytes) -> bytes:
    """
    Encrypt data to be uploaded to TKS with the negotiated symmetric key.

    The result is nonce || ciphertext || MAC, as produced by `AesKey.encrypt`.
    Per this module's conventions the nonce is `AES_NONCE_LEN` bytes, the MAC is
    `AES_MAC_LEN` bytes, and the ciphertext is as long as `plaintext`
    (`decrypt_key` expects the same layout).

    Args:
        symmetric_key: The symmetric key negotiated with TKS.
        plaintext: The data to upload to TKS.
    Returns:
        The encrypted data.
    """
    cipher = AesKey(symmetric_key)
    iv, body, tag = cipher.encrypt(plaintext)
    return iv + body + tag


def decrypt_key(symmetric_key: bytes, message: bytes) -> bytes:
    """
    Decrypt data returned by TKS with the negotiated symmetric key.

    This is the inverse of `encrypt_key`: `message` must be
    nonce || ciphertext || MAC, with a nonce of `AES_NONCE_LEN` bytes, a MAC of
    `AES_MAC_LEN` bytes, and a ciphertext whose length is in `ALLOWED_KEY_LEN`.

    Args:
        symmetric_key: The symmetric key negotiated with TKS.
        message: The data returned by TKS.
    Returns:
        The decrypted data.
    Raises:
        error.DecryptionError: If `message` does not have an accepted total length.
    """
    # Whatever is left between the fixed-size nonce and MAC is the ciphertext; it
    # must be key-sized. A total length shorter than nonce + MAC also fails here.
    body_len = len(message) - AES_NONCE_LEN - AES_MAC_LEN
    if body_len not in ALLOWED_KEY_LEN:
        raise error.DecryptionError("Invalid data length")

    mac_start = AES_NONCE_LEN + body_len
    nonce = message[:AES_NONCE_LEN]
    ciphertext = message[AES_NONCE_LEN:mac_start]
    mac = message[mac_start:]

    return AesKey(symmetric_key).decrypt(nonce, ciphertext, mac)


def respond_challenge(
    challenge: ClientChallenge, app: AppOptions, config: TksConfig, eps_client: EpsClient
) -> ChallengeResponse:
    """
    Build the response to a `ClientChallenge` issued by the TKS service.

    The response always echoes the challenge (`ClientChall`). It also carries
    whichever of the following evidence applies:

    - `RAEvidence`: a quote from `eps_client` for the challenge's `NonceDown`,
      tagged with TEE type "TDX"; only when `config.tks_enable_bi_auth` is set.
    - `CertEvidence`: a signature of the nonce; only when `app.signing_key` is set.
    - `ZTIEvidence`: the app's ZTI token; only when `app.zti_token` is set.
    """

    answer = ChallengeResponse(ClientChall=challenge)
    nonce_down = challenge["NonceDown"]

    if config.tks_enable_bi_auth:
        answer["RAEvidence"] = RAEvidence(TEEType="TDX", Report=eps_client.get_quote(nonce_down))

    if app.signing_key:
        answer["CertEvidence"] = CertEvidence(
            Signature=sign_message(nonce_down.encode(), app.signing_key)
        )

    if app.zti_token:
        answer["ZTIEvidence"] = ZTIEvidence(Token=app.zti_token)

    return answer


@dataclass(eq=False)
class TksClient:
    """
    Client for the Trusted Key Service (TKS).

    All methods are documented as thread-safe: none of them mutates the client
    itself. This also assumes the configured RAS/EPS clients are safe to share
    across threads.
    """

    config: TksConfig
    """Configuration needed to access the TKS service."""

    # Evaluates the attestation report of TKS; required when
    # `config.tks_enable_server_auth` is true.
    ras_client: Optional[RasClient] = field(repr=False, default=None)

    # Answers a `ClientChallenge`; required by `export_key` when TKS issues one.
    eps_client: Optional[EpsClient] = field(repr=False, default=None)

    def _attest(self, bi_auth: bool) -> Tuple[str, Optional[ClientChallenge]]:
        """
        Perform remote attestation through the TKS SecurityAttestation endpoint.

        Args:
            bi_auth: Whether TKS should also verify this client (mutual attestation).
                Has no effect when `config.tks_enable_bi_auth` is false.
        Returns:
            The TKS public key in PEM form (usable with `key_exchange`), and the
            optional `ClientChallenge` issued by TKS.
        Raises:
            error.ConfigError: If server authentication is enabled but no RAS client
                is configured. This is checked before any request is sent.
        """
        ras_client: Optional[RasClient] = None
        if self.config.tks_enable_server_auth:
            if self.ras_client is None:
                # Fail fast: the report could not be evaluated anyway, so don't contact TKS.
                raise error.ConfigError(
                    self.config, "Cannot evaluate attestation report of TKS: RAS is not configured"
                )
            ras_client = self.ras_client

        nonce = secrets.token_hex(32)
        params = SecurityAttestation(
            NonceUp=nonce, BiAuth=bi_auth and self.config.tks_enable_bi_auth
        )
        result: AttestationResult = tks_request(
            "/api/tks/v1/security/attest", params, None, self.config
        )

        ra_type = result["RAType"]
        quote = result["Report"]
        server_pub_key = base64.standard_b64decode(result["DHParam"]).decode()
        client_challenge = result.get("Challenge")

        if ra_type not in ("TDX", "coco"):
            logger.warning(f"Unknown {ra_type=} in attestation")

        if ras_client is not None:
            # NOTE(review): the report is always evaluated as "TDX", regardless of `ra_type`.
            # NOTE(review): nothing here checks that the report is bound to `nonce` or to
            # `server_pub_key`; confirm RAS (or the report format) enforces this, otherwise
            # the DH key returned below is not authenticated by the attestation.
            token = ras_client.get_attestation_evaluation("TDX", quote)
            logger.info(f"Attestation: token={token}")

        return server_pub_key, client_challenge

    def export_key(self, app: AppOptions, ring_id: str, key_id: str) -> Tuple[bytes, int]:
        """
        Download and decrypt a key through the TKS KeyExportation endpoint.

        The method first runs attestation, requesting mutual attestation. It then
        negotiates a fresh symmetric key via ECDH and answers the client challenge
        if TKS issued one. Finally it decrypts the returned key.

        Args:
            app: App-specific authentication options (sent as `AppID` and as headers).
            ring_id: The `RingID` of the key.
            key_id: The `KeyID` of the key.
        Returns:
            The decrypted key, and its version.
        Raises:
            error.ConfigError: If RAS is needed but not configured, or TKS issued a
                client challenge and no EPS client is configured.
        """
        server_pub_key, client_challenge = self._attest(bi_auth=True)

        client_pub_key, shared_key = key_exchange(server_pub_key)

        params = KeyExportation(
            AppID=app.id,
            RingID=ring_id,
            KeyID=key_id,
            DHParam=base64.standard_b64encode(client_pub_key).decode(),
        )
        if client_challenge:
            if self.eps_client is None:
                raise error.ConfigError(
                    self.config, "Cannot respond to client challenge: EPS is not configured"
                )
            challenge_response = respond_challenge(
                client_challenge, app, self.config, self.eps_client
            )
            params.update(challenge_response)  # type: ignore

        result: ExportationResult = tks_request("/api/tks/v1/key/export", params, app, self.config)

        logger.info("Export successful")

        key = base64.standard_b64decode(result["Key"])

        return decrypt_key(shared_key, key), result["Version"]

    def import_key(self, app: AppOptions, ring_id: str, key_id: str, key: bytes) -> int:
        """
        Encrypt a key and upload it through the TKS KeyImportation endpoint.

        Only the server is attested here; mutual attestation is not requested.

        Args:
            app: App-specific authentication options, sent as request headers.
            ring_id: The `RingID` to import the key into.
            key_id: The `KeyID` of the imported key.
            key: The key material; must be 16 or 32 bytes long.
        Returns:
            The version of the newly imported key.
        Raises:
            error.ParamError: If `key` does not have an accepted length.
            error.ConfigError: If RAS is needed but not configured.
        """
        if len(key) not in ALLOWED_KEY_LEN:
            raise error.ParamError("key", f"Key length {len(key)} is not acceptable to TKS")

        server_pub_key, _ = self._attest(bi_auth=False)

        client_pub_key, shared_key = key_exchange(server_pub_key)

        params = KeyImportation(
            RingID=ring_id,
            KeyID=key_id,
            Key=base64.standard_b64encode(encrypt_key(shared_key, key)).decode(),
            DHParam=base64.standard_b64encode(client_pub_key).decode(),
        )

        result: KeyIndexWithVersion = tks_request(
            "/api/tks/v1/key/import", params, app, self.config
        )

        logger.info(f"Import successful: {result}")

        return result["Version"]
