import warnings
from typing import List, Optional

import bytedance.jeddak_secure_channel as jsc
import msgpack
import requests

from .error import CryptoRAGError, CryptoRAGErrorCode, CryptoRAGErrorMsg



class CryptoRAGClient:
    """Client for the CryptoRAG dense-embedding encryption service.

    Embeddings can be sent either over a plain HTTP channel
    (``req_embed_encryption_plain_channel``) or wrapped with a Jeddak
    secure-channel client (``req_embed_encryption_secure_channel``). The secure
    path can optionally fall back to the plain one.
    """

    # Endpoint paths, appended to ``server_url``.
    api_path = {
        "EmbedEncPlainChannel": "/api/rag/v1/crypto/dpe/enc_embedding",
        "EmbedEncSecureChannel": "/api/rag/v2/crypto/dpe/enc_embedding",
    }

    # Client-side limits enforced by ``embeddings_check``.
    MAX_VECTORS = 100
    MAX_DIMENSION = 8192

    def __init__(
        self,
        account_id: str,
        rag_app_id: str,
        server_url: str,
        secure_channel_client: jsc.Client,
        *,
        timeout: Optional[float] = None,
    ):
        """
        :param account_id: str, Volcengine account ID (sent as the ``X-Top-Account-Id`` header)
        :param rag_app_id: str, application ID of the RAG encryption service
        :param server_url: str, base URL of the server (a trailing ``/`` is tolerated)
        :param secure_channel_client: jsc.Client, secure (confidential) channel client
        :param timeout: optional ``requests`` timeout in seconds applied to every
            HTTP call; ``None`` (the default) keeps the previous behaviour of
            waiting indefinitely
        """
        self.account_id = str(account_id)
        self.rag_app_id = rag_app_id
        self.server_url = server_url
        self.secure_channel_client = secure_channel_client
        self.timeout = timeout

    def embeddings_check(self, embeddings: List[List[float]]):
        """Validate the shape and element type of ``embeddings``.

        At most ``MAX_VECTORS`` vectors are accepted. Every vector (not just the
        first one) must be non-empty and hold at most ``MAX_DIMENSION``
        elements. Every element must be a ``float``.

        :raises CryptoRAGError: with ``CryptoRAGErrorCode.INVALID_PARAM`` on any violation
        """
        if not embeddings or not embeddings[0]:
            raise CryptoRAGError(CryptoRAGErrorCode.INVALID_PARAM, "Embeddings cannot be empty")
        if len(embeddings) > self.MAX_VECTORS:
            raise CryptoRAGError(
                CryptoRAGErrorCode.INVALID_PARAM,
                f"The first dimension cannot exceed {self.MAX_VECTORS}",
            )
        # Bound every row, not only embeddings[0]: a short first row must not
        # let an empty or oversized later row through.
        for vec in embeddings:
            if not vec:
                raise CryptoRAGError(CryptoRAGErrorCode.INVALID_PARAM, "Embeddings cannot be empty")
            if len(vec) > self.MAX_DIMENSION:
                raise CryptoRAGError(
                    CryptoRAGErrorCode.INVALID_PARAM,
                    f"The second dimension cannot exceed {self.MAX_DIMENSION}",
                )

        if not all(isinstance(x, float) for vec in embeddings for x in vec):
            raise CryptoRAGError(CryptoRAGErrorCode.INVALID_PARAM, "Element type should be float")

    def _endpoint_url(self, api_name: str) -> str:
        """Join ``server_url`` and the named API path without doubling the ``/``."""
        return self.server_url.rstrip("/") + self.api_path[api_name]

    def _post(self, api_name: str, req: dict):
        """POST ``req`` as JSON to the named endpoint.

        :return: ``(response, parsed_json)``. ``parsed_json`` is ``None`` when
            the body is not valid JSON (e.g. an HTML error page from a proxy),
            so callers can still report ``resp.text``.
        """
        resp = requests.post(
            self._endpoint_url(api_name),
            json=req,
            headers={"X-Top-Account-Id": self.account_id},
            timeout=self.timeout,
        )
        try:
            resp_json = resp.json()
        except ValueError:  # requests' JSON decode errors subclass ValueError
            resp_json = None
        return resp, resp_json

    @staticmethod
    def _extract_error(resp, resp_json, context: str):
        """Return ``(code, message)`` from ``ResponseMetadata.Error`` of an error reply.

        :raises KeyError: ``"<context>: <raw body>"`` when the metadata is
            missing or the body was not a JSON object
        """
        try:
            error = resp_json["ResponseMetadata"]["Error"]
            return error["Code"], error["Message"]
        except (KeyError, TypeError) as e:
            raise KeyError(f"{context}: {resp.text}") from e

    @staticmethod
    def _is_jsc_error(error_code) -> bool:
        """True if the server's error code equals ``CryptoRAGErrorCode.JSC_ERROR``.

        Codes that cannot be converted to ``int`` are treated as "not JSC"
        instead of raising, so the real server error is still reported.
        """
        try:
            return int(error_code) == CryptoRAGErrorCode.JSC_ERROR
        except (TypeError, ValueError):
            return False

    def req_embed_encryption_plain_channel(self, embeddings: List[List[float]]):
        """Encrypt ``embeddings`` by sending them in the clear to the plain-channel endpoint.

        Emits a ``UserWarning`` because the raw vectors leave the client unprotected
        by the secure channel.

        :param embeddings: List[List[float]]
        :return: List[List[float]], the ``Result.CipherDenseVectors`` field of the reply
        :raises CryptoRAGError: on invalid input, or with the server's error
            code/message on a non-200 reply
        :raises KeyError: when the reply's error metadata or result cannot be extracted
        """
        self.embeddings_check(embeddings)
        warnings.warn("Invoke embedding encryption through plain channel", stacklevel=2)

        resp, resp_json = self._post(
            "EmbedEncPlainChannel",
            {"RagAppID": self.rag_app_id, "DenseVectors": embeddings},
        )
        if resp.status_code != 200:
            error_code, error_msg = self._extract_error(
                resp, resp_json, "Failed to extract error metadata from response"
            )
            raise CryptoRAGError(error_code, f"{error_msg}")

        try:
            return resp_json["Result"]["CipherDenseVectors"]
        except (KeyError, TypeError) as e:
            raise KeyError(f"Failed to extract result from response: {resp.text}") from e

    def req_embed_encryption_secure_channel(self, embeddings: List[List[float]], disable_ra_during_downgrade: bool = True):
        """Encrypt ``embeddings`` through the secure-channel endpoint.

        The vectors are msgpack-serialised and encrypted with
        ``secure_channel_client.encrypt_with_response``. The reply's
        ``Result.CipherDenseVectors`` is decrypted with the returned key and
        msgpack-decoded.

        :param embeddings: List[List[float]]
        :param disable_ra_during_downgrade: when True, a failure in any
            secure-channel step falls back to
            ``req_embed_encryption_plain_channel``. The steps are local
            encryption, a server reply whose code is ``JSC_ERROR``, and
            extracting/decrypting/decoding the result. When False, local and
            result-side failures raise ``CryptoRAGError(JSC_ERROR)``.
        :return: List[List[float]]
        :raises CryptoRAGError: on invalid input, a server error, or a
            secure-channel failure when fallback is disabled
        :raises KeyError: when an error reply carries no parseable error metadata
        """
        self.embeddings_check(embeddings)
        try:
            buf: bytes = msgpack.packb(embeddings)  # type: ignore
            encrypted_msg, enc_key = self.secure_channel_client.encrypt_with_response(buf)
        except Exception as e:
            # Deliberately broad: any failure to set up the secure channel is a
            # reason to downgrade (or to report JSC_ERROR).
            if disable_ra_during_downgrade:
                return self.req_embed_encryption_plain_channel(embeddings)
            raise CryptoRAGError(CryptoRAGErrorCode.JSC_ERROR, CryptoRAGErrorMsg.JSC_ERROR) from e

        resp, resp_json = self._post(
            "EmbedEncSecureChannel",
            {"RagAppID": self.rag_app_id, "DenseVectors": encrypted_msg},
        )
        if resp.status_code != 200:
            error_code, error_msg = self._extract_error(resp, resp_json, "Failed to get error metadata")
            if disable_ra_during_downgrade and self._is_jsc_error(error_code):
                return self.req_embed_encryption_plain_channel(embeddings)
            raise CryptoRAGError(error_code, f"{error_msg}")

        try:
            encrypted_vectors = resp_json["Result"]["CipherDenseVectors"]
            vector_bytes: bytes = enc_key.decrypt(encrypted_vectors)
            return msgpack.unpackb(vector_bytes)
        except Exception as e:
            # Deliberately broad: a malformed result or failed decryption counts
            # as a secure-channel failure.
            if disable_ra_during_downgrade:
                return self.req_embed_encryption_plain_channel(embeddings)
            raise CryptoRAGError(CryptoRAGErrorCode.JSC_ERROR, CryptoRAGErrorMsg.JSC_ERROR) from e
    