|  | @@ -0,0 +1,124 @@
 | 
											
												
													
														|  | 
 |  | +import websocket
 | 
											
												
													
														|  | 
 |  | +import json
 | 
											
												
													
														|  | 
 |  | +import _thread
 | 
											
												
													
														|  | 
 |  | +import ssl
 | 
											
												
													
														|  | 
 |  | +from urllib.parse import urlparse, urlencode
 | 
											
												
													
														|  | 
 |  | +from wsgiref.handlers import format_date_time
 | 
											
												
													
														|  | 
 |  | +from datetime import datetime
 | 
											
												
													
														|  | 
 |  | +from time import mktime
 | 
											
												
													
														|  | 
 |  | +import hashlib
 | 
											
												
													
														|  | 
 |  | +import hmac
 | 
											
												
													
														|  | 
 |  | +import base64
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +"""
 | 
											
												
													
														|  | 
 |  | +星火语言模型API
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +@param system: 背景信息
 | 
											
												
													
														|  | 
 |  | +    - Type: String
 | 
											
												
													
														|  | 
 |  | +    - Example: system="现在扮演李白,你豪情万丈,狂放不羁;接下来请用李白的口吻和用户对话。"
 | 
											
												
													
														|  | 
 |  | +@param history: 历史记录示例
 | 
											
												
													
														|  | 
 |  | +    - Type: List[dict]
 | 
											
												
													
														|  | 
 |  | +    - Example:
 | 
											
												
													
														|  | 
 |  | +        history=[
 | 
											
												
													
														|  | 
 |  | +            {"role": "user", "content": "你给我写一首诗,以兄弟为题不包含兄弟2字,七步内成诗,不行杀了你"},
 | 
											
												
													
														|  | 
 |  | +            {"role": "assistant", "content": "君若问兄弟,我以诗代言。江山共长天,一脉同根源。风雨共行路,携手笑苍天。豪情万丈志,生死共肩连。肝胆相照映,此心永不偏。纵有千般难,并肩共前缘。人生如梦幻,唯愿共度年。此诗赠君子,兄弟情谊传。"},
 | 
											
												
													
														|  | 
 |  | +        ]
 | 
											
												
													
														|  | 
 |  | +@param query: 用户输入字符串
 | 
											
												
													
														|  | 
 |  | +    - Type: String
 | 
											
												
													
														|  | 
 |  | +    - Example: query="不好意思,你超过7步了,拖下去砍了。"
 | 
											
												
													
														|  | 
 |  | +"""
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class ChatClient:
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, APPID, APIKey, APISecret, gpt_url, domain, query, history=None, system=None):
 | 
											
												
													
														|  | 
 |  | +        self.ws = None
 | 
											
												
													
														|  | 
 |  | +        self.APPID = APPID
 | 
											
												
													
														|  | 
 |  | +        self.APIKey = APIKey
 | 
											
												
													
														|  | 
 |  | +        self.APISecret = APISecret
 | 
											
												
													
														|  | 
 |  | +        self.gpt_url = gpt_url
 | 
											
												
													
														|  | 
 |  | +        self.domain = domain
 | 
											
												
													
														|  | 
 |  | +        self.query = query
 | 
											
												
													
														|  | 
 |  | +        self.system = system
 | 
											
												
													
														|  | 
 |  | +        # 历史记录应该是一个包含角色和内容的字典列表
 | 
											
												
													
														|  | 
 |  | +        self.history = history or []
 | 
											
												
													
														|  | 
 |  | +        self.response = ""
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def create_url(self):
 | 
											
												
													
														|  | 
 |  | +        host = urlparse(self.gpt_url).netloc
 | 
											
												
													
														|  | 
 |  | +        path = urlparse(self.gpt_url).path
 | 
											
												
													
														|  | 
 |  | +        now = datetime.now()
 | 
											
												
													
														|  | 
 |  | +        date = format_date_time(mktime(now.timetuple()))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        signature_origin = "host: " + host + "\n" + "date: " + date + "\n" + "GET " + path + " HTTP/1.1"
 | 
											
												
													
														|  | 
 |  | +        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
 | 
											
												
													
														|  | 
 |  | +                                 digestmod=hashlib.sha256).digest()
 | 
											
												
													
														|  | 
 |  | +        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
 | 
											
												
													
														|  | 
 |  | +        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        v = {"authorization": authorization, "date": date, "host": host}
 | 
											
												
													
														|  | 
 |  | +        return self.gpt_url + '?' + urlencode(v)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def on_message(self, ws, message):
 | 
											
												
													
														|  | 
 |  | +        data = json.loads(message)
 | 
											
												
													
														|  | 
 |  | +        code = data['header']['code']
 | 
											
												
													
														|  | 
 |  | +        if code == 0:
 | 
											
												
													
														|  | 
 |  | +            choices = data["payload"]["choices"]
 | 
											
												
													
														|  | 
 |  | +            content = choices["text"][0]["content"]
 | 
											
												
													
														|  | 
 |  | +            self.response += content
 | 
											
												
													
														|  | 
 |  | +            if choices["status"] == 2:
 | 
											
												
													
														|  | 
 |  | +                ws.close()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def on_error(self, error):
 | 
											
												
													
														|  | 
 |  | +        print(f"WebSocket error: {error}")
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def on_close(self):
 | 
											
												
													
														|  | 
 |  | +        print("WebSocket connection closed")
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def on_open(self, ws):
 | 
											
												
													
														|  | 
 |  | +        def run(*args):
 | 
											
												
													
														|  | 
 |  | +            data = json.dumps(self.gen_params())
 | 
											
												
													
														|  | 
 |  | +            ws.send(data)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        _thread.start_new_thread(run, ())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def gen_params(self):
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        生成请求参数,包括历史记录和系统信息
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        texts = []
 | 
											
												
													
														|  | 
 |  | +        if self.system:
 | 
											
												
													
														|  | 
 |  | +            texts.append({"role": "system", "content": self.system})
 | 
											
												
													
														|  | 
 |  | +        texts.extend(self.history)
 | 
											
												
													
														|  | 
 |  | +        texts.append({"role": "user", "content": self.query})
 | 
											
												
													
														|  | 
 |  | +        return {
 | 
											
												
													
														|  | 
 |  | +            "header": {
 | 
											
												
													
														|  | 
 |  | +                "app_id": self.APPID,
 | 
											
												
													
														|  | 
 |  | +                "uid": "1234",
 | 
											
												
													
														|  | 
 |  | +            },
 | 
											
												
													
														|  | 
 |  | +            "parameter": {
 | 
											
												
													
														|  | 
 |  | +                "chat": {
 | 
											
												
													
														|  | 
 |  | +                    "domain": self.domain,
 | 
											
												
													
														|  | 
 |  | +                    "temperature": 0.5,
 | 
											
												
													
														|  | 
 |  | +                    "max_tokens": 4096,
 | 
											
												
													
														|  | 
 |  | +                    "auditing": "default",
 | 
											
												
													
														|  | 
 |  | +                }
 | 
											
												
													
														|  | 
 |  | +            },
 | 
											
												
													
														|  | 
 |  | +            "payload": {
 | 
											
												
													
														|  | 
 |  | +                "message":
 | 
											
												
													
														|  | 
 |  | +                    {
 | 
											
												
													
														|  | 
 |  | +                        "text": texts
 | 
											
												
													
														|  | 
 |  | +                    }
 | 
											
												
													
														|  | 
 |  | +            }
 | 
											
												
													
														|  | 
 |  | +        }
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def start(self):
 | 
											
												
													
														|  | 
 |  | +        websocket.enableTrace(False)
 | 
											
												
													
														|  | 
 |  | +        self.ws = websocket.WebSocketApp(self.create_url(),
 | 
											
												
													
														|  | 
 |  | +                                         on_message=lambda ws, msg: self.on_message(ws, msg),
 | 
											
												
													
														|  | 
 |  | +                                         on_error=lambda msg: self.on_error(msg),
 | 
											
												
													
														|  | 
 |  | +                                         on_close=self.on_close,
 | 
											
												
													
														|  | 
 |  | +                                         on_open=lambda ws: self.on_open(ws))  # 使用 lambda 来确保 ws 参数传递
 | 
											
												
													
														|  | 
 |  | +        self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
 | 
											
												
													
														|  | 
 |  | +        return self.response
 |