memo_webhook.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
import json
import logging
import os

import pymysql
import uvicorn as uvicorn
from dbutils.pooled_db import PooledDB
from fastapi import FastAPI, Request
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import JSONResponse
  11. class MySQLPool:
  12. def __init__(self, host, port, user, password, database, pool_size=5):
  13. self.pool = PooledDB(
  14. creator=pymysql, # 指定 MySQL 的 Python 客户端库
  15. maxconnections=pool_size, # 连接池中最多存在的连接数
  16. host=host,
  17. port=port,
  18. user=user,
  19. password=password,
  20. database=database,
  21. charset='utf8mb4',
  22. cursorclass=pymysql.cursors.DictCursor,
  23. )
  24. def get_connection(self):
  25. return self.pool.connection()
  26. def select(self, sql, params=None):
  27. conn = self.get_connection()
  28. try:
  29. with conn.cursor() as cursor:
  30. cursor.execute(sql, params)
  31. conn.commit()
  32. return cursor.fetchall()
  33. finally:
  34. conn.close()
  35. def exec(self, sql, params=None):
  36. conn = self.get_connection()
  37. try:
  38. with conn.cursor() as cursor:
  39. result = cursor.execute(sql, params)
  40. conn.commit()
  41. return result
  42. finally:
  43. conn.close()
  44. db_memos = MySQLPool("101.43.195.48", 3306, "memos", "s3ijTsaH5cciP8s8", "memos")
  45. from qdrant_client import QdrantClient
  46. from qdrant_client.models import Distance, VectorParams
  47. from qdrant_client.models import PointStruct
  48. from qdrant_client.models import Filter, FieldCondition, MatchValue
  49. from text2vec import SentenceModel
  50. class MyVector:
  51. def __init__(self, ip, port=6333, grpc_port=6334, api_key=None, embedding_path=None):
  52. self.url = "http://" + ip
  53. self.port = port
  54. self.grpc_port = grpc_port
  55. self.api_key = api_key
  56. self.client = QdrantClient(
  57. url=self.url,
  58. port=self.port,
  59. grpc_port=self.grpc_port,
  60. api_key=self.api_key
  61. )
  62. self.model = SentenceModel(embedding_path)
  63. def create_collection(self, collection_name="test_collection", size=16):
  64. self.client.create_collection(
  65. # 设置索引的名称
  66. collection_name=collection_name,
  67. # 设置索引中输入向量的长度
  68. # 参数size是数据维度
  69. # 参数distance是计算的方法,主要有COSINE(余弦),EUCLID(欧氏距离)、DOT(点积),MANHATTAN(曼哈顿距离)
  70. vectors_config=VectorParams(size=size, distance=Distance.COSINE),
  71. )
  72. def delete_collection(self, collection_name="test_collection"):
  73. self.client.delete_collection(
  74. collection_name=collection_name,
  75. )
  76. def create_document(self, collection_name="test_collection", doc_id=None, vector=None, payload=None):
  77. if not isinstance(vector, list):
  78. raise ValueError("Vector must be a list")
  79. if not all(isinstance(item, float) for item in vector):
  80. raise ValueError("Vector must be a list of floats")
  81. operation_info = self.client.upsert(
  82. collection_name=collection_name,
  83. wait=True,
  84. points=[
  85. PointStruct(id=doc_id, vector=vector, payload=payload),
  86. ],
  87. )
  88. # 返回值
  89. # operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>
  90. print(operation_info)
  91. def delete_documentById(self, collection_name="test_collection", doc_id=None):
  92. operation_info = self.client.delete(
  93. collection_name=collection_name,
  94. wait=True,
  95. points_selector=Filter(
  96. should=[
  97. FieldCondition(
  98. key="id",
  99. match=MatchValue(value=doc_id)
  100. )
  101. ]
  102. ),
  103. )
  104. # 返回值
  105. # operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>
  106. print(operation_info)
  107. def query_data(self, collection_name="test_collection", query_vector=None, limit=10, filter_map=None):
  108. if filter_map is not None:
  109. # 遍历 filter_map 的 key和value
  110. query_filter = Filter(
  111. must=[
  112. FieldCondition(
  113. key=key,
  114. match=MatchValue(value=value)
  115. )
  116. for key, value in filter_map.items()
  117. ]
  118. )
  119. else:
  120. query_filter = None
  121. search_result = self.client.search(
  122. # 设置索引
  123. collection_name=collection_name,
  124. # 查询向量
  125. query_vector=query_vector,
  126. # 限制返回值的数量
  127. limit=limit,
  128. query_filter=query_filter,
  129. )
  130. # 返回值
  131. # [ScoredPoint(id=4, version=0, score=1.362, payload={'city': 'New York'}, vector=None, shard_key=None), ScoredPoint(id=1, version=0, score=1.273, payload={'city': 'Berlin'}, vector=None, shard_key=None), ScoredPoint(id=3, version=0, score=1.208, payload={'city': 'Moscow'}, vector=None, shard_key=None)]
  132. return search_result
  133. client = MyVector("101.43.195.48", port=6333, grpc_port=6334, api_key="tianyunperfect123456", embedding_path='/app/memos/models--shibing624--text2vec-base-chinese')
  134. def getMemoById(id):
  135. return db_memos.select("select * from memo where id = %s", (id,))
  136. app = FastAPI(title="memo_webhook name ", description="系统描述 ", version="v 0.0.0")
  137. app.add_middleware(SessionMiddleware, secret_key='123456hhh')
  138. app.add_middleware(
  139. CORSMiddleware,
  140. allow_origins=["*", ],
  141. allow_credentials=True,
  142. allow_methods=["*"],
  143. allow_headers=["*"],
  144. )
  145. @app.exception_handler(Exception)
  146. def validation_exception_handler(request, exc):
  147. """请求校验异常捕获; application/json """
  148. return JSONResponse({'message': "服务器内部错误", 'status_code': 500})
  149. @app.get('/test')
  150. async def test():
  151. return {'msg': 'hello world'}
class Item(BaseModel):
    # Pydantic model with three fields. No route visible in this file uses it:
    # the /post handler parses the raw JSON body and reads camelCase keys
    # (e.g. 'creatorId') instead.
    # NOTE(review): presumably an earlier or planned schema for the webhook
    # payload — confirm before relying on it or removing it.
    creator_id: int
    created_ts: int
    # Typed as ``object``, so any value is accepted here.
    memo: object
  156. @app.post('/post')
  157. async def test123(request: Request):
  158. body = await request.json()
  159. print(body)
  160. creator_id = body['creatorId']
  161. # 如果creatorId 不是1
  162. if creator_id != 1:
  163. return {'msg': 'hello world'}
  164. activity_type = body['activityType']
  165. is_upsert = False
  166. is_delete = False
  167. content_ = body['memo']["content"]
  168. id_ = body['memo']["id"]
  169. resource_name = ''
  170. if activity_type == 'memos.memo.deleted':
  171. is_delete = True
  172. else:
  173. by_id = getMemoById(id_)[0]
  174. print(by_id)
  175. if by_id['row_status'] != 'NORMAL':
  176. is_delete = True
  177. else:
  178. is_upsert = True
  179. resource_name = by_id['resource_name']
  180. if is_delete:
  181. print("delete")
  182. client.delete_documentById(collection_name="memo_collection", doc_id=id_)
  183. elif is_upsert:
  184. print("upsert")
  185. encode = client.model.encode(content_)
  186. # encode转list
  187. encode = list(encode)
  188. # 内部全部转为float
  189. encode = [float(i) for i in encode]
  190. client.create_document(collection_name="memo_collection", doc_id=id_, vector=encode,
  191. payload={"content": content_, 'resource_name': resource_name})
  192. return {'msg': 'hello world'}
  193. @app.get('/query')
  194. async def query_data(request: Request):
  195. q_ = request.query_params['q']
  196. print(q_)
  197. model_encode = client.model.encode(q_)
  198. encode = list(model_encode)
  199. encode = [float(i) for i in encode]
  200. return client.query_data(collection_name="memo_collection", query_vector=encode, limit=10)
  201. if __name__ == '__main__':
  202. uvicorn.run(app='memo_webhook:app', host="127.0.0.1", port=8000, reload=True)
# Startup commands
# pip3 install -r /tmp/request.txt -i https://pypi.douban.com/simple/
# pip3 install uvicorn gunicorn fastapi -i https://pypi.douban.com/simple/
#
# gunicorn tmp_app:app -b 0.0.0.0:8000 -w 4 -k uvicorn.workers.UvicornWorker -D # original note: "not supported when running on Linux"; -w is the number of worker processes
# uvicorn memo_webhook:app --reload --host 0.0.0.0 --port 9001