瀏覽代碼

讯飞星火大模型

linhaohong 1 年之前
父節點
當前提交
5782618e3e
共有 1 個文件被更改,包括 124 次插入0 次删除
  1. 124 0
      Object/WsParam/AIChatObject.py

+ 124 - 0
Object/WsParam/AIChatObject.py

@@ -0,0 +1,124 @@
+import websocket
+import json
+import _thread
+import ssl
+from urllib.parse import urlparse, urlencode
+from wsgiref.handlers import format_date_time
+from datetime import datetime
+from time import mktime
+import hashlib
+import hmac
+import base64
+
+"""
+星火语言模型API
+
+@param system: 背景信息
+    - Type: String
+    - Example: system="现在扮演李白,你豪情万丈,狂放不羁;接下来请用李白的口吻和用户对话。"
+@param history: 历史记录示例
+    - Type: List[dict]
+    - Example:
+        history=[
+            {"role": "user", "content": "你给我写一首诗,以兄弟为题不包含兄弟2字,七步内成诗,不行杀了你"},
+            {"role": "assistant", "content": "君若问兄弟,我以诗代言。江山共长天,一脉同根源。风雨共行路,携手笑苍天。豪情万丈志,生死共肩连。肝胆相照映,此心永不偏。纵有千般难,并肩共前缘。人生如梦幻,唯愿共度年。此诗赠君子,兄弟情谊传。"},
+        ]
+@param query: 用户输入字符串
+    - Type: String
+    - Example: query="不好意思,你超过7步了,拖下去砍了。"
+"""
+
+
+class ChatClient:
+    def __init__(self, APPID, APIKey, APISecret, gpt_url, domain, query, history=None, system=None):
+        self.ws = None
+        self.APPID = APPID
+        self.APIKey = APIKey
+        self.APISecret = APISecret
+        self.gpt_url = gpt_url
+        self.domain = domain
+        self.query = query
+        self.system = system
+        # 历史记录应该是一个包含角色和内容的字典列表
+        self.history = history or []
+        self.response = ""
+
+    def create_url(self):
+        host = urlparse(self.gpt_url).netloc
+        path = urlparse(self.gpt_url).path
+        now = datetime.now()
+        date = format_date_time(mktime(now.timetuple()))
+
+        signature_origin = "host: " + host + "\n" + "date: " + date + "\n" + "GET " + path + " HTTP/1.1"
+        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
+                                 digestmod=hashlib.sha256).digest()
+        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
+
+        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
+        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
+
+        v = {"authorization": authorization, "date": date, "host": host}
+        return self.gpt_url + '?' + urlencode(v)
+
+    def on_message(self, ws, message):
+        data = json.loads(message)
+        code = data['header']['code']
+        if code == 0:
+            choices = data["payload"]["choices"]
+            content = choices["text"][0]["content"]
+            self.response += content
+            if choices["status"] == 2:
+                ws.close()
+
+    def on_error(self, error):
+        print(f"WebSocket error: {error}")
+
+    def on_close(self):
+        print("WebSocket connection closed")
+
+    def on_open(self, ws):
+        def run(*args):
+            data = json.dumps(self.gen_params())
+            ws.send(data)
+
+        _thread.start_new_thread(run, ())
+
+    def gen_params(self):
+        """
+        生成请求参数,包括历史记录和系统信息
+        """
+        texts = []
+        if self.system:
+            texts.append({"role": "system", "content": self.system})
+        texts.extend(self.history)
+        texts.append({"role": "user", "content": self.query})
+        return {
+            "header": {
+                "app_id": self.APPID,
+                "uid": "1234",
+            },
+            "parameter": {
+                "chat": {
+                    "domain": self.domain,
+                    "temperature": 0.5,
+                    "max_tokens": 4096,
+                    "auditing": "default",
+                }
+            },
+            "payload": {
+                "message":
+                    {
+                        "text": texts
+                    }
+            }
+        }
+
+    def start(self):
+        websocket.enableTrace(False)
+        self.ws = websocket.WebSocketApp(self.create_url(),
+                                         on_message=lambda ws, msg: self.on_message(ws, msg),
+                                         on_error=lambda msg: self.on_error(msg),
+                                         on_close=self.on_close,
+                                         on_open=lambda ws: self.on_open(ws))  # 使用 lambda 来确保 ws 参数传递
+        self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
+        return self.response