|
@@ -0,0 +1,124 @@
|
|
|
+import websocket
|
|
|
+import json
|
|
|
+import _thread
|
|
|
+import ssl
|
|
|
+from urllib.parse import urlparse, urlencode
|
|
|
+from wsgiref.handlers import format_date_time
|
|
|
+from datetime import datetime
|
|
|
+from time import mktime
|
|
|
+import hashlib
|
|
|
+import hmac
|
|
|
+import base64
|
|
|
+
|
|
|
+"""
|
|
|
+星火语言模型API
|
|
|
+
|
|
|
+@param system: 背景信息
|
|
|
+ - Type: String
|
|
|
+ - Example: system="现在扮演李白,你豪情万丈,狂放不羁;接下来请用李白的口吻和用户对话。"
|
|
|
+@param history: 历史记录示例
|
|
|
+ - Type: List[dict]
|
|
|
+ - Example:
|
|
|
+ history=[
|
|
|
+ {"role": "user", "content": "你给我写一首诗,以兄弟为题不包含兄弟2字,七步内成诗,不行杀了你"},
|
|
|
+ {"role": "assistant", "content": "君若问兄弟,我以诗代言。江山共长天,一脉同根源。风雨共行路,携手笑苍天。豪情万丈志,生死共肩连。肝胆相照映,此心永不偏。纵有千般难,并肩共前缘。人生如梦幻,唯愿共度年。此诗赠君子,兄弟情谊传。"},
|
|
|
+ ]
|
|
|
+@param query: 用户输入字符串
|
|
|
+ - Type: String
|
|
|
+ - Example: query="不好意思,你超过7步了,拖下去砍了。"
|
|
|
+"""
|
|
|
+
|
|
|
+
|
|
|
+class ChatClient:
|
|
|
+ def __init__(self, APPID, APIKey, APISecret, gpt_url, domain, query, history=None, system=None):
|
|
|
+ self.ws = None
|
|
|
+ self.APPID = APPID
|
|
|
+ self.APIKey = APIKey
|
|
|
+ self.APISecret = APISecret
|
|
|
+ self.gpt_url = gpt_url
|
|
|
+ self.domain = domain
|
|
|
+ self.query = query
|
|
|
+ self.system = system
|
|
|
+ # 历史记录应该是一个包含角色和内容的字典列表
|
|
|
+ self.history = history or []
|
|
|
+ self.response = ""
|
|
|
+
|
|
|
+ def create_url(self):
|
|
|
+ host = urlparse(self.gpt_url).netloc
|
|
|
+ path = urlparse(self.gpt_url).path
|
|
|
+ now = datetime.now()
|
|
|
+ date = format_date_time(mktime(now.timetuple()))
|
|
|
+
|
|
|
+ signature_origin = "host: " + host + "\n" + "date: " + date + "\n" + "GET " + path + " HTTP/1.1"
|
|
|
+ signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
|
|
+ digestmod=hashlib.sha256).digest()
|
|
|
+ signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
|
|
+
|
|
|
+ authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
|
|
+ authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
|
|
+
|
|
|
+ v = {"authorization": authorization, "date": date, "host": host}
|
|
|
+ return self.gpt_url + '?' + urlencode(v)
|
|
|
+
|
|
|
+ def on_message(self, ws, message):
|
|
|
+ data = json.loads(message)
|
|
|
+ code = data['header']['code']
|
|
|
+ if code == 0:
|
|
|
+ choices = data["payload"]["choices"]
|
|
|
+ content = choices["text"][0]["content"]
|
|
|
+ self.response += content
|
|
|
+ if choices["status"] == 2:
|
|
|
+ ws.close()
|
|
|
+
|
|
|
+ def on_error(self, error):
|
|
|
+ print(f"WebSocket error: {error}")
|
|
|
+
|
|
|
+ def on_close(self):
|
|
|
+ print("WebSocket connection closed")
|
|
|
+
|
|
|
+ def on_open(self, ws):
|
|
|
+ def run(*args):
|
|
|
+ data = json.dumps(self.gen_params())
|
|
|
+ ws.send(data)
|
|
|
+
|
|
|
+ _thread.start_new_thread(run, ())
|
|
|
+
|
|
|
+ def gen_params(self):
|
|
|
+ """
|
|
|
+ 生成请求参数,包括历史记录和系统信息
|
|
|
+ """
|
|
|
+ texts = []
|
|
|
+ if self.system:
|
|
|
+ texts.append({"role": "system", "content": self.system})
|
|
|
+ texts.extend(self.history)
|
|
|
+ texts.append({"role": "user", "content": self.query})
|
|
|
+ return {
|
|
|
+ "header": {
|
|
|
+ "app_id": self.APPID,
|
|
|
+ "uid": "1234",
|
|
|
+ },
|
|
|
+ "parameter": {
|
|
|
+ "chat": {
|
|
|
+ "domain": self.domain,
|
|
|
+ "temperature": 0.5,
|
|
|
+ "max_tokens": 4096,
|
|
|
+ "auditing": "default",
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "payload": {
|
|
|
+ "message":
|
|
|
+ {
|
|
|
+ "text": texts
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ def start(self):
|
|
|
+ websocket.enableTrace(False)
|
|
|
+ self.ws = websocket.WebSocketApp(self.create_url(),
|
|
|
+ on_message=lambda ws, msg: self.on_message(ws, msg),
|
|
|
+ on_error=lambda msg: self.on_error(msg),
|
|
|
+ on_close=self.on_close,
|
|
|
+ on_open=lambda ws: self.on_open(ws)) # 使用 lambda 来确保 ws 参数传递
|
|
|
+ self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
|
|
+ return self.response
|