import argparse
import base64
import json
import os

import requests
from flask import Response

import bytedance.jeddak_secure_channel as jsc


def getenv_with_empty_check(key, default):
    """Return the environment variable's value, or the default when it is unset or empty."""
    value = os.getenv(key)
    return default if not value else value

# Select the secure-channel config that matches the deployment environment (default: Online).
jsc_env = getenv_with_empty_check("jsc_env", "Online")
if jsc_env == "Online":
    secure_channel_config = jsc.ClientConfig.from_file("configs/client_config_online.json")
elif jsc_env == "PPE":
    secure_channel_config = jsc.ClientConfig.from_file("configs/client_config_ppe.json")
elif jsc_env == "veStack":
    secure_channel_config = jsc.ClientConfig.from_file("configs/client_config_vestack.json")
elif jsc_env == "Pre":
    secure_channel_config = jsc.ClientConfig.from_file("configs/client_config_pre.json")
else:
    raise ValueError(f"Unsupported jsc_env value: {jsc_env!r}")
secure_channel_client = jsc.Client(secure_channel_config)
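# Note: as used in the methods below, secure_channel_client.encrypt_with_response(...)
# encrypts a sensitive request field and returns a per-request key object whose
# decrypt() method recovers plaintext from the service's encrypted responses.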


class VLLMClient:
    def __init__(self, base_url: str, authorization_token: str = None):
        """
        Initialize the vLLM client.

        :param base_url: Base URL of the vLLM service, e.g. "http://localhost:8000".
        :param authorization_token: Optional bearer token added to the Authorization header.
        """
        self.base_url = base_url.rstrip("/")
        self.authorization_token = authorization_token
        self.headers = {"Content-Type": "application/json"}
        if self.authorization_token:
            self.headers["Authorization"] = f"Bearer {self.authorization_token}"

    def generate_completion(self, model: str, prompt: str, max_tokens: int = 100, temperature: float = 1.0):
        """
        Call the /v2/completions endpoint to generate a text completion.

        :param model: Model name.
        :param prompt: Input prompt text.
        :param max_tokens: Maximum number of tokens to generate.
        :param temperature: Sampling temperature (lower values are more deterministic).
        :return: A Flask Response wrapping the decrypted JSON result, or None if the request fails.
        """
        url = f"{self.base_url}/v2/completions"
        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        # Encrypt the sensitive field before sending it over the channel.
        encrypted_prompt, encrypt_key = secure_channel_client.encrypt_with_response(payload["prompt"])
        payload["prompt"] = encrypted_prompt

        try:
            response = requests.post(url, json=payload, headers=self.headers)
            response.raise_for_status()
            data = response.json()
            # Decrypt each returned choice back to plaintext.
            for elm in data['choices']:
                elm['text'] = encrypt_key.decrypt(elm['text']).decode()
            return Response(response=json.dumps(data), status=200, content_type="application/json")
        except requests.RequestException as e:
            print(f"Error during request: {e}")
            return None

    def generate_chat_completion(self, model: str, messages: list, max_tokens: int = 100, temperature: float = 1.0):
        """
        Call the /v2/chat/completions endpoint to generate a chat completion.

        :param model: Model name.
        :param messages: List of chat messages.
        :param max_tokens: Maximum number of tokens to generate.
        :param temperature: Sampling temperature (lower values are more deterministic).
        :return: A Flask Response wrapping the decrypted JSON result, or None if the request fails.
        """
        url = f"{self.base_url}/v2/chat/completions"
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        # Encrypt the sensitive field before sending it over the channel.
        encrypted_messages, encrypt_key = secure_channel_client.encrypt_with_response(json.dumps(payload["messages"]))
        payload["messages"] = encrypted_messages

        try:
            response = requests.post(url, json=payload, headers=self.headers)
            response.raise_for_status()
            data = response.json()
            # Decrypt each returned choice back to plaintext.
            for elm in data['choices']:
                elm['message'] = json.loads(encrypt_key.decrypt(elm['message']))
            return Response(response=json.dumps(data), status=200, content_type="application/json")
        except requests.RequestException as e:
            print(f"Error during request: {e}")
            return None

    def generate_completion_stream(self, model: str, prompt: str, max_tokens: int = 100, temperature: float = 1.0):
        """
        Call the /v2/completions endpoint and stream the generated text completion.

        :param model: Model name.
        :param prompt: Input prompt text.
        :param max_tokens: Maximum number of tokens to generate.
        :param temperature: Sampling temperature (lower values are more deterministic).
        """
        url = f"{self.base_url}/v2/completions"
        payload = {
            "model": model,
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True  # Enable streaming responses
        }

        # Encrypt the sensitive field before sending it over the channel.
        encrypted_prompt, encrypt_key = secure_channel_client.encrypt_with_response(payload["prompt"])
        payload["prompt"] = encrypted_prompt
        try:
            with requests.post(url, json=payload, headers=self.headers, stream=True) as response:
                response.raise_for_status()
                for chunk in response.iter_lines():
                    if chunk:
                        # Each chunk is base64-encoded ciphertext; decode and decrypt it.
                        data = encrypt_key.decrypt(base64.b64decode(chunk))
                        print(f"Received chunk: {data}")
                        # Process each chunk here, e.g. parse it as JSON or print it directly.
        except requests.RequestException as e:
            print(f"Error during request: {e}")

    def generate_chat_completion_stream(self, model: str, messages: list, max_tokens: int = 100, temperature: float = 1.0):
        """
        Call the /v2/chat/completions endpoint and stream the generated chat completion.

        :param model: Model name.
        :param messages: List of chat messages.
        :param max_tokens: Maximum number of tokens to generate.
        :param temperature: Sampling temperature (lower values are more deterministic).
        """
        url = f"{self.base_url}/v2/chat/completions"
        payload = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True  # Enable streaming responses
        }

        # Encrypt the sensitive field before sending it over the channel.
        encrypted_messages, encrypt_key = secure_channel_client.encrypt_with_response(json.dumps(payload["messages"]))
        payload["messages"] = encrypted_messages
        try:
            with requests.post(url, json=payload, headers=self.headers, stream=True) as response:
                response.raise_for_status()
                for chunk in response.iter_lines():
                    if chunk:
                        # Each chunk is base64-encoded ciphertext; decode and decrypt it.
                        data = encrypt_key.decrypt(base64.b64decode(chunk))
                        print(f"Received chunk: {data}")
                        # Process each chunk here, e.g. parse it as JSON or print it directly.
        except requests.RequestException as e:
            print(f"Error during request: {e}")

# Example usage
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Example client for the vLLM secure-channel service")

    # Command-line arguments
    parser.add_argument('--ip', '-i', type=str, default="localhost", help='IP address of the service')
    parser.add_argument('--port', '-p', type=str, default='8080', help='Port of the service')
    parser.add_argument('--endpoint', '-e', type=str, help='Endpoint of the service (overrides --ip/--port)')
    parser.add_argument('--model_name', '-m', type=str, default="decrypted_model/Qwen32B", help='Model name')
    parser.add_argument('--authorization_token', '-auth', type=str, default=None, help='Authorization token')
    args = parser.parse_args()

    base_url = f"http://{args.ip}:{args.port}"
    if args.endpoint:
        base_url = args.endpoint

    client = VLLMClient(base_url, args.authorization_token)  # Replace with your vLLM service address
    model_name = args.model_name
    prompt = "Once upon a time in a faraway land,"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Tell me a story about San Francisco."}
    ]

    response_completion = client.generate_completion(model_name, prompt, max_tokens=50, temperature=0.7)
    print("Generated Text:", response_completion.json)
    response_chat_completion = client.generate_chat_completion(model_name, messages, max_tokens=100, temperature=0.7)
    print("Generated Chat Response:", response_chat_completion.json)

    client.generate_completion_stream(model_name, prompt, max_tokens=50, temperature=0.7)

    client.generate_chat_completion_stream(model_name, messages, max_tokens=100, temperature=0.7)
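
    # Minimal invocation sketch (the script filename below is a placeholder; adjust the
    # endpoint, model name, and token to your deployment):
    #   python vllm_secure_client.py --endpoint http://localhost:8080 \
    #       --model_name decrypted_model/Qwen32B --authorization_token <token>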