|
@@ -0,0 +1,248 @@
|
|
|
|
+import json
|
|
|
|
+import logging
|
|
|
|
+
|
|
|
|
+import uvicorn as uvicorn
|
|
|
|
+from fastapi import FastAPI, Request
|
|
|
|
+from pydantic import BaseModel
|
|
|
|
+from starlette.middleware.cors import CORSMiddleware
|
|
|
|
+from starlette.middleware.sessions import SessionMiddleware
|
|
|
|
+from starlette.responses import JSONResponse
|
|
|
|
+import pymysql
|
|
|
|
+from dbutils.pooled_db import PooledDB
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class MySQLPool:
|
|
|
|
+ def __init__(self, host, port, user, password, database, pool_size=5):
|
|
|
|
+ self.pool = PooledDB(
|
|
|
|
+ creator=pymysql, # 指定 MySQL 的 Python 客户端库
|
|
|
|
+ maxconnections=pool_size, # 连接池中最多存在的连接数
|
|
|
|
+ host=host,
|
|
|
|
+ port=port,
|
|
|
|
+ user=user,
|
|
|
|
+ password=password,
|
|
|
|
+ database=database,
|
|
|
|
+ charset='utf8mb4',
|
|
|
|
+ cursorclass=pymysql.cursors.DictCursor,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ def get_connection(self):
|
|
|
|
+ return self.pool.connection()
|
|
|
|
+
|
|
|
|
+ def select(self, sql, params=None):
|
|
|
|
+ conn = self.get_connection()
|
|
|
|
+ try:
|
|
|
|
+ with conn.cursor() as cursor:
|
|
|
|
+ cursor.execute(sql, params)
|
|
|
|
+ conn.commit()
|
|
|
|
+ return cursor.fetchall()
|
|
|
|
+ finally:
|
|
|
|
+ conn.close()
|
|
|
|
+
|
|
|
|
+ def exec(self, sql, params=None):
|
|
|
|
+ conn = self.get_connection()
|
|
|
|
+ try:
|
|
|
|
+ with conn.cursor() as cursor:
|
|
|
|
+ result = cursor.execute(sql, params)
|
|
|
|
+ conn.commit()
|
|
|
|
+ return result
|
|
|
|
+ finally:
|
|
|
|
+ conn.close()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+db_memos = MySQLPool("101.43.195.48", 3306, "memos", "s3ijTsaH5cciP8s8", "memos")
|
|
|
|
+
|
|
|
|
+from qdrant_client import QdrantClient
|
|
|
|
+from qdrant_client.models import Distance, VectorParams
|
|
|
|
+
|
|
|
|
+from qdrant_client.models import PointStruct
|
|
|
|
+
|
|
|
|
+from qdrant_client.models import Filter, FieldCondition, MatchValue
|
|
|
|
+from text2vec import SentenceModel
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class MyVector:
|
|
|
|
+ def __init__(self, ip, port=6333, grpc_port=6334, api_key=None, embedding_path=None):
|
|
|
|
+ self.url = "http://" + ip
|
|
|
|
+ self.port = port
|
|
|
|
+ self.grpc_port = grpc_port
|
|
|
|
+ self.api_key = api_key
|
|
|
|
+ self.client = QdrantClient(
|
|
|
|
+ url=self.url,
|
|
|
|
+ port=self.port,
|
|
|
|
+ grpc_port=self.grpc_port,
|
|
|
|
+ api_key=self.api_key
|
|
|
|
+ )
|
|
|
|
+ self.model = SentenceModel(embedding_path)
|
|
|
|
+
|
|
|
|
+ def create_collection(self, collection_name="test_collection", size=16):
|
|
|
|
+ self.client.create_collection(
|
|
|
|
+ # 设置索引的名称
|
|
|
|
+ collection_name=collection_name,
|
|
|
|
+ # 设置索引中输入向量的长度
|
|
|
|
+ # 参数size是数据维度
|
|
|
|
+ # 参数distance是计算的方法,主要有COSINE(余弦),EUCLID(欧氏距离)、DOT(点积),MANHATTAN(曼哈顿距离)
|
|
|
|
+ vectors_config=VectorParams(size=size, distance=Distance.COSINE),
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ def delete_collection(self, collection_name="test_collection"):
|
|
|
|
+ self.client.delete_collection(
|
|
|
|
+ collection_name=collection_name,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ def create_document(self, collection_name="test_collection", doc_id=None, vector=None, payload=None):
|
|
|
|
+ if not isinstance(vector, list):
|
|
|
|
+ raise ValueError("Vector must be a list")
|
|
|
|
+ if not all(isinstance(item, float) for item in vector):
|
|
|
|
+ raise ValueError("Vector must be a list of floats")
|
|
|
|
+ operation_info = self.client.upsert(
|
|
|
|
+ collection_name=collection_name,
|
|
|
|
+ wait=True,
|
|
|
|
+ points=[
|
|
|
|
+ PointStruct(id=doc_id, vector=vector, payload=payload),
|
|
|
|
+ ],
|
|
|
|
+ )
|
|
|
|
+ # 返回值
|
|
|
|
+ # operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>
|
|
|
|
+ print(operation_info)
|
|
|
|
+
|
|
|
|
+ def delete_documentById(self, collection_name="test_collection", doc_id=None):
|
|
|
|
+ operation_info = self.client.delete(
|
|
|
|
+ collection_name=collection_name,
|
|
|
|
+ wait=True,
|
|
|
|
+ points_selector=Filter(
|
|
|
|
+ should=[
|
|
|
|
+ FieldCondition(
|
|
|
|
+ key="id",
|
|
|
|
+ match=MatchValue(value=doc_id)
|
|
|
|
+ )
|
|
|
|
+ ]
|
|
|
|
+ ),
|
|
|
|
+ )
|
|
|
|
+ # 返回值
|
|
|
|
+ # operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>
|
|
|
|
+ print(operation_info)
|
|
|
|
+
|
|
|
|
+ def query_data(self, collection_name="test_collection", query_vector=None, limit=10, filter_map=None):
|
|
|
|
+ if filter_map is not None:
|
|
|
|
+ # 遍历 filter_map 的 key和value
|
|
|
|
+ query_filter = Filter(
|
|
|
|
+ must=[
|
|
|
|
+ FieldCondition(
|
|
|
|
+ key=key,
|
|
|
|
+ match=MatchValue(value=value)
|
|
|
|
+ )
|
|
|
|
+ for key, value in filter_map.items()
|
|
|
|
+ ]
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ query_filter = None
|
|
|
|
+
|
|
|
|
+ search_result = self.client.search(
|
|
|
|
+ # 设置索引
|
|
|
|
+ collection_name=collection_name,
|
|
|
|
+ # 查询向量
|
|
|
|
+ query_vector=query_vector,
|
|
|
|
+ # 限制返回值的数量
|
|
|
|
+ limit=limit,
|
|
|
|
+ query_filter=query_filter,
|
|
|
|
+ )
|
|
|
|
+ # 返回值
|
|
|
|
+ # [ScoredPoint(id=4, version=0, score=1.362, payload={'city': 'New York'}, vector=None, shard_key=None), ScoredPoint(id=1, version=0, score=1.273, payload={'city': 'Berlin'}, vector=None, shard_key=None), ScoredPoint(id=3, version=0, score=1.208, payload={'city': 'Moscow'}, vector=None, shard_key=None)]
|
|
|
|
+ return search_result
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+client = MyVector("101.43.195.48", port=6333, grpc_port=6334, api_key="tianyunperfect123456", embedding_path='/app/memos/models--shibing624--text2vec-base-chinese')
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def getMemoById(id):
|
|
|
|
+ return db_memos.select("select * from memo where id = %s", (id,))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+app = FastAPI(title="memo_webhook name ", description="系统描述 ", version="v 0.0.0")
|
|
|
|
+app.add_middleware(SessionMiddleware, secret_key='123456hhh')
|
|
|
|
+app.add_middleware(
|
|
|
|
+ CORSMiddleware,
|
|
|
|
+ allow_origins=["*", ],
|
|
|
|
+ allow_credentials=True,
|
|
|
|
+ allow_methods=["*"],
|
|
|
|
+ allow_headers=["*"],
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.exception_handler(Exception)
|
|
|
|
+def validation_exception_handler(request, exc):
|
|
|
|
+ """请求校验异常捕获; application/json """
|
|
|
|
+ return JSONResponse({'message': "服务器内部错误", 'status_code': 500})
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.get('/test')
|
|
|
|
+async def test():
|
|
|
|
+ return {'msg': 'hello world'}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class Item(BaseModel):
|
|
|
|
+ creator_id: int
|
|
|
|
+ created_ts: int
|
|
|
|
+ memo: object
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post('/post')
|
|
|
|
+async def test123(request: Request):
|
|
|
|
+ body = await request.json()
|
|
|
|
+ print(body)
|
|
|
|
+ creator_id = body['creatorId']
|
|
|
|
+ # 如果creatorId 不是1
|
|
|
|
+ if creator_id != 1:
|
|
|
|
+ return {'msg': 'hello world'}
|
|
|
|
+
|
|
|
|
+ activity_type = body['activityType']
|
|
|
|
+ is_upsert = False
|
|
|
|
+ is_delete = False
|
|
|
|
+ content_ = body['memo']["content"]
|
|
|
|
+ id_ = body['memo']["id"]
|
|
|
|
+ resource_name = ''
|
|
|
|
+ if activity_type == 'memos.memo.deleted':
|
|
|
|
+ is_delete = True
|
|
|
|
+ else:
|
|
|
|
+ by_id = getMemoById(id_)[0]
|
|
|
|
+ print(by_id)
|
|
|
|
+ if by_id['row_status'] != 'NORMAL':
|
|
|
|
+ is_delete = True
|
|
|
|
+ else:
|
|
|
|
+ is_upsert = True
|
|
|
|
+ resource_name = by_id['resource_name']
|
|
|
|
+ if is_delete:
|
|
|
|
+ print("delete")
|
|
|
|
+ client.delete_documentById(collection_name="memo_collection", doc_id=id_)
|
|
|
|
+ elif is_upsert:
|
|
|
|
+ print("upsert")
|
|
|
|
+ encode = client.model.encode(content_)
|
|
|
|
+ # encode转list
|
|
|
|
+ encode = list(encode)
|
|
|
|
+ # 内部全部转为float
|
|
|
|
+ encode = [float(i) for i in encode]
|
|
|
|
+ client.create_document(collection_name="memo_collection", doc_id=id_, vector=encode,
|
|
|
|
+ payload={"content": content_, 'resource_name': resource_name})
|
|
|
|
+
|
|
|
|
+ return {'msg': 'hello world'}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.get('/query')
|
|
|
|
+async def query_data(request: Request):
|
|
|
|
+ q_ = request.query_params['q']
|
|
|
|
+ print(q_)
|
|
|
|
+ model_encode = client.model.encode(q_)
|
|
|
|
+ encode = list(model_encode)
|
|
|
|
+ encode = [float(i) for i in encode]
|
|
|
|
+ return client.query_data(collection_name="memo_collection", query_vector=encode, limit=10)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+if __name__ == '__main__':
|
|
|
|
+ uvicorn.run(app='memo_webhook:app', host="127.0.0.1", port=8000, reload=True)
|
|
|
|
+
|
|
|
|
+# 启动命令
|
|
|
|
+# pip3 install -r /tmp/request.txt -i https://pypi.douban.com/simple/
|
|
|
|
+# pip3 install uvicorn gunicorn fastapi -i https://pypi.douban.com/simple/
|
|
|
|
+#
|
|
|
|
+# gunicorn tmp_app:app -b 0.0.0.0:8000 -w 4 -k uvicorn.workers.UvicornWorker -D # 不支持linux运行 -w 进程数
|
|
|
|
+# uvicorn memo_webhook:app --reload --host 0.0.0.0 --port 9001
|