AIChatObject.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import websocket
  2. import json
  3. import _thread
  4. import ssl
  5. from urllib.parse import urlparse, urlencode
  6. from wsgiref.handlers import format_date_time
  7. from datetime import datetime
  8. from time import mktime
  9. import hashlib
  10. import hmac
  11. import base64
  12. """
  13. 星火语言模型API
  14. @param system: 背景信息
  15. - Type: String
  16. - Example: system="现在扮演李白,你豪情万丈,狂放不羁;接下来请用李白的口吻和用户对话。"
  17. @param history: 历史记录示例
  18. - Type: List[dict]
  19. - Example:
  20. history=[
  21. {"role": "user", "content": "你给我写一首诗,以兄弟为题不包含兄弟2字,七步内成诗,不行杀了你"},
  22. {"role": "assistant", "content": "君若问兄弟,我以诗代言。江山共长天,一脉同根源。风雨共行路,携手笑苍天。豪情万丈志,生死共肩连。肝胆相照映,此心永不偏。纵有千般难,并肩共前缘。人生如梦幻,唯愿共度年。此诗赠君子,兄弟情谊传。"},
  23. ]
  24. @param query: 用户输入字符串
  25. - Type: String
  26. - Example: query="不好意思,你超过7步了,拖下去砍了。"
  27. """
  28. class ChatClient:
  29. def __init__(self, APPID, APIKey, APISecret, gpt_url, domain, query, history=None, system=None):
  30. self.ws = None
  31. self.APPID = APPID
  32. self.APIKey = APIKey
  33. self.APISecret = APISecret
  34. self.gpt_url = gpt_url
  35. self.domain = domain
  36. self.query = query
  37. self.system = system
  38. # 历史记录应该是一个包含角色和内容的字典列表
  39. self.history = history or []
  40. self.response = ""
  41. def create_url(self):
  42. host = urlparse(self.gpt_url).netloc
  43. path = urlparse(self.gpt_url).path
  44. now = datetime.now()
  45. date = format_date_time(mktime(now.timetuple()))
  46. signature_origin = "host: " + host + "\n" + "date: " + date + "\n" + "GET " + path + " HTTP/1.1"
  47. signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
  48. digestmod=hashlib.sha256).digest()
  49. signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
  50. authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  51. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  52. v = {"authorization": authorization, "date": date, "host": host}
  53. return self.gpt_url + '?' + urlencode(v)
  54. def on_message(self, ws, message):
  55. data = json.loads(message)
  56. code = data['header']['code']
  57. if code == 0:
  58. choices = data["payload"]["choices"]
  59. content = choices["text"][0]["content"]
  60. self.response += content
  61. if choices["status"] == 2:
  62. ws.close()
  63. def on_error(self, ws, error):
  64. print(f"WebSocket error: {error}")
  65. def on_close(self, ws):
  66. print("WebSocket connection closed")
  67. def on_open(self, ws):
  68. def run(*args):
  69. data = json.dumps(self.gen_params())
  70. ws.send(data)
  71. _thread.start_new_thread(run, ())
  72. def gen_params(self):
  73. """
  74. 生成请求参数,包括历史记录和系统信息
  75. """
  76. texts = []
  77. if self.system:
  78. texts.append({"role": "system", "content": self.system})
  79. texts.extend(self.history)
  80. texts.append({"role": "user", "content": self.query})
  81. return {
  82. "header": {
  83. "app_id": self.APPID,
  84. "uid": "1234",
  85. },
  86. "parameter": {
  87. "chat": {
  88. "domain": self.domain,
  89. "temperature": 0.5,
  90. "max_tokens": 4096,
  91. "auditing": "default",
  92. }
  93. },
  94. "payload": {
  95. "message":
  96. {
  97. "text": texts
  98. }
  99. }
  100. }
  101. def start(self):
  102. websocket.enableTrace(True)
  103. self.ws = websocket.WebSocketApp(self.create_url(),
  104. on_message=lambda ws, msg: self.on_message(ws, msg),
  105. on_error=lambda ws, error: self.on_error(ws, error),
  106. on_close=lambda ws: self.on_close(ws),
  107. on_open=lambda ws: self.on_open(ws))
  108. self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  109. return self.response
  110. if __name__ == "__main__":
  111. appid = "fcff8f4b"
  112. api_secret = "ZTU3NWMyNTI1MTI4NTU5ZGUxMDZhNmQ5"
  113. api_key = "037571e7285e64e8dc321fa5b937fea2"
  114. gpt_url = "wss://spark-api.xf-yun.com/v3.5/chat"
  115. domain = "generalv3.5"
  116. system = "现在扮演李白,你豪情万丈,狂放不羁;接下来请用李白的口吻和用户对话。"
  117. query = "不好意思,你超过7步了,拖下去砍了。"
  118. history = [
  119. {"role": "user", "content": "你给我写一首诗,以兄弟为题不包含兄弟2字,七步内成诗,不行杀了你"},
  120. {"role": "assistant",
  121. "content": "君若问兄弟,我以诗代言。江山共长天,一脉同根源。风雨共行路,携手笑苍天。豪情万丈志,生死共肩连。肝胆相照映,此心永不偏。纵有千般难,并肩共前缘。人生如梦幻,唯愿共度年。此诗赠君子,兄弟情谊传。"},
  122. ]
  123. chat = ChatClient(appid, api_key, api_secret, gpt_url, domain, query, history, system)
  124. response = chat.start()
  125. print(response)