"""Webhook service that mirrors memos into a Qdrant vector collection.

The service receives memo activity webhooks on ``POST /post``, looks the memo
up in MySQL, and either upserts its embedding into Qdrant or removes it.
``GET /query`` embeds a query string and returns the nearest stored memos.
"""
import json
import logging
import os

import pymysql
import uvicorn
from dbutils.pooled_db import PooledDB
from fastapi import FastAPI, Request
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from qdrant_client.models import PointStruct
from qdrant_client.models import Filter, FieldCondition, MatchValue
from qdrant_client.models import PointIdsList
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import JSONResponse
from text2vec import SentenceModel

logger = logging.getLogger(__name__)

# NOTE(review): the defaults below are the values that used to be hard-coded
# in this file. They are kept only for backward compatibility; set the
# environment variables in deployment and rotate the committed secrets.
DB_HOST = os.environ.get("MEMOS_DB_HOST", "101.43.195.48")
DB_PORT = int(os.environ.get("MEMOS_DB_PORT", "3306"))
DB_USER = os.environ.get("MEMOS_DB_USER", "memos")
DB_PASSWORD = os.environ.get("MEMOS_DB_PASSWORD", "s3ijTsaH5cciP8s8")
DB_NAME = os.environ.get("MEMOS_DB_NAME", "memos")

QDRANT_HOST = os.environ.get("MEMOS_QDRANT_HOST", "101.43.195.48")
QDRANT_PORT = int(os.environ.get("MEMOS_QDRANT_PORT", "6333"))
QDRANT_GRPC_PORT = int(os.environ.get("MEMOS_QDRANT_GRPC_PORT", "6334"))
QDRANT_API_KEY = os.environ.get("MEMOS_QDRANT_API_KEY", "tianyunperfect123456")
EMBEDDING_PATH = os.environ.get(
    "MEMOS_EMBEDDING_PATH",
    '/app/memos/models--shibing624--text2vec-base-chinese',
)

SESSION_SECRET_KEY = os.environ.get("MEMOS_SESSION_SECRET", '123456hhh')

MEMO_COLLECTION = "memo_collection"


class MySQLPool:
    """Thin wrapper around a DBUtils ``PooledDB`` pool of PyMySQL connections."""

    def __init__(self, host, port, user, password, database, pool_size=5):
        """Create the pool.

        Args:
            host: MySQL server host.
            port: MySQL server port.
            user: Login user.
            password: Login password.
            database: Default schema.
            pool_size: Maximum number of connections allowed in the pool.
        """
        self.pool = PooledDB(
            creator=pymysql,  # DB-API module used to create connections
            maxconnections=pool_size,  # upper bound on pooled connections
            host=host,
            port=port,
            user=user,
            password=password,
            database=database,
            charset='utf8mb4',
            cursorclass=pymysql.cursors.DictCursor,
        )

    def get_connection(self):
        """Return a connection borrowed from the pool."""
        return self.pool.connection()

    def select(self, sql, params=None):
        """Run a query and return all rows (dicts, because of ``DictCursor``).

        Args:
            sql: SQL text with ``%s`` placeholders.
            params: Values bound to the placeholders, or ``None``.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(sql, params)
                # Commit before fetching, exactly as the original did, so the
                # pooled connection does not keep a transaction open.
                conn.commit()
                return cursor.fetchall()
        finally:
            # Always hand the connection back to the pool.
            conn.close()

    def exec(self, sql, params=None):
        """Run a statement, commit, and return ``cursor.execute``'s result.

        Args:
            sql: SQL text with ``%s`` placeholders.
            params: Values bound to the placeholders, or ``None``.
        """
        conn = self.get_connection()
        try:
            with conn.cursor() as cursor:
                result = cursor.execute(sql, params)
                conn.commit()
                return result
        finally:
            conn.close()


db_memos = MySQLPool(DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME)


class MyVector:
    """Convenience wrapper bundling a Qdrant client with a sentence encoder."""

    def __init__(self, ip, port=6333, grpc_port=6334, api_key=None, embedding_path=None):
        """Connect to Qdrant and load the embedding model.

        Args:
            ip: Qdrant host; ``http://`` is prepended to build the URL.
            port: Qdrant HTTP port.
            grpc_port: Qdrant gRPC port.
            api_key: Qdrant API key, if any.
            embedding_path: Path or name passed to ``SentenceModel``.
        """
        self.url = "http://" + ip
        self.port = port
        self.grpc_port = grpc_port
        self.api_key = api_key
        self.client = QdrantClient(
            url=self.url,
            port=self.port,
            grpc_port=self.grpc_port,
            api_key=self.api_key,
        )
        self.model = SentenceModel(embedding_path)

    def create_collection(self, collection_name="test_collection", size=16):
        """Create a collection using cosine distance.

        Args:
            collection_name: Name of the collection (index) to create.
            size: Dimensionality of the vectors stored in the collection.
        """
        self.client.create_collection(
            collection_name=collection_name,
            # ``size`` is the vector dimension. ``distance`` selects the
            # metric: COSINE, EUCLID (Euclidean), DOT (dot product) or
            # MANHATTAN.
            vectors_config=VectorParams(size=size, distance=Distance.COSINE),
        )

    def delete_collection(self, collection_name="test_collection"):
        """Delete the named collection."""
        self.client.delete_collection(
            collection_name=collection_name,
        )

    def create_document(self, collection_name="test_collection", doc_id=None, vector=None, payload=None):
        """Upsert a single point and print the operation info.

        Args:
            collection_name: Target collection.
            doc_id: Point id.
            vector: Embedding; must be a ``list`` of ``float``.
            payload: Payload stored with the point.

        Raises:
            ValueError: If ``vector`` is not a list of floats.
        """
        if not isinstance(vector, list):
            raise ValueError("Vector must be a list")
        if not all(isinstance(item, float) for item in vector):
            raise ValueError("Vector must be a list of floats")
        operation_info = self.client.upsert(
            collection_name=collection_name,
            wait=True,
            points=[
                PointStruct(id=doc_id, vector=vector, payload=payload),
            ],
        )
        # The upsert result is printed for debugging.
        print(operation_info)

    def delete_documentById(self, collection_name="test_collection", doc_id=None):
        """Delete the point whose id is ``doc_id`` and print the operation info.

        The previous implementation filtered on a payload field named ``id``.
        ``create_document`` stores the id as the point id, not in the payload,
        so that filter never matched and nothing was deleted. Points are now
        selected by id directly.

        Args:
            collection_name: Target collection.
            doc_id: Id of the point to remove.
        """
        operation_info = self.client.delete(
            collection_name=collection_name,
            wait=True,
            points_selector=PointIdsList(points=[doc_id]),
        )
        # The delete result is printed for debugging.
        print(operation_info)

    def query_data(self, collection_name="test_collection", query_vector=None, limit=10, filter_map=None):
        """Search for the nearest points to ``query_vector``.

        Args:
            collection_name: Collection to search.
            query_vector: Query embedding.
            limit: Maximum number of results.
            filter_map: Optional mapping of payload key to required value;
                every pair must match (``must`` conditions).

        Returns:
            The result of ``QdrantClient.search`` (a list of scored points,
            e.g. ``ScoredPoint(id=4, version=0, score=1.362,
            payload={'city': 'New York'}, ...)``).
        """
        if filter_map is not None:
            # One exact-match condition per key/value pair in filter_map.
            query_filter = Filter(
                must=[
                    FieldCondition(key=key, match=MatchValue(value=value))
                    for key, value in filter_map.items()
                ]
            )
        else:
            query_filter = None
        search_result = self.client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit,
            query_filter=query_filter,
        )
        return search_result


client = MyVector(
    QDRANT_HOST,
    port=QDRANT_PORT,
    grpc_port=QDRANT_GRPC_PORT,
    api_key=QDRANT_API_KEY,
    embedding_path=EMBEDDING_PATH,
)


def getMemoById(id):
    """Return the rows of table ``memo`` whose ``id`` equals the argument."""
    return db_memos.select("select * from memo where id = %s", (id,))


def _embed(text):
    """Encode ``text`` with the shared model and return a list of floats."""
    return [float(value) for value in client.model.encode(text)]


app = FastAPI(title="memo_webhook name ", description="系统描述 ", version="v 0.0.0")
app.add_middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*", ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.exception_handler(Exception)
def validation_exception_handler(request, exc):
    """Catch-all exception handler; responds with an application/json body.

    The exception is logged with its traceback so failures are no longer
    silently swallowed. The response (body and default HTTP status) is the
    same as before, so existing callers are unaffected.
    """
    logger.error("Unhandled exception while serving %s", request.url, exc_info=exc)
    return JSONResponse({'message': "服务器内部错误", 'status_code': 500})


@app.get('/test')
async def test():
    """Liveness probe returning a fixed greeting."""
    return {'msg': 'hello world'}


class Item(BaseModel):
    """Schema with creator id, creation timestamp and memo (unused in this file)."""

    creator_id: int
    created_ts: int
    memo: object


@app.post('/post')
async def test123(request: Request):
    """Handle a memo activity webhook.

    Reads ``creatorId``, ``activityType`` and ``memo.content`` / ``memo.id``
    from the JSON body. Events from creators other than 1 are ignored. A
    ``memos.memo.deleted`` event, a memo whose ``row_status`` is not
    ``NORMAL``, or a memo that no longer exists in the database removes the
    point from the vector collection; otherwise the memo is embedded and
    upserted.
    """
    body = await request.json()
    print(body)
    creator_id = body['creatorId']
    # Only memos created by user 1 are indexed.
    if creator_id != 1:
        return {'msg': 'hello world'}

    activity_type = body['activityType']
    content_ = body['memo']["content"]
    id_ = body['memo']["id"]
    resource_name = ''

    is_delete = activity_type == 'memos.memo.deleted'
    if not is_delete:
        rows = getMemoById(id_)
        # An empty result used to raise IndexError; a memo that is missing
        # from the database is now treated like a deleted one.
        by_id = rows[0] if rows else None
        print(by_id)
        if by_id is None or by_id['row_status'] != 'NORMAL':
            is_delete = True
        else:
            resource_name = by_id['resource_name']

    if is_delete:
        print("delete")
        client.delete_documentById(collection_name=MEMO_COLLECTION, doc_id=id_)
    else:
        print("upsert")
        client.create_document(
            collection_name=MEMO_COLLECTION,
            doc_id=id_,
            vector=_embed(content_),
            payload={"content": content_, 'resource_name': resource_name},
        )
    return {'msg': 'hello world'}


@app.get('/query')
async def query_data(request: Request):
    """Embed query parameter ``q`` and return the 10 nearest memos."""
    q_ = request.query_params['q']
    print(q_)
    return client.query_data(collection_name=MEMO_COLLECTION, query_vector=_embed(q_), limit=10)


if __name__ == '__main__':
    uvicorn.run(app='memo_webhook:app', host="127.0.0.1", port=8000, reload=True)

# Startup commands
# pip3 install -r /tmp/request.txt -i https://pypi.douban.com/simple/
# pip3 install uvicorn gunicorn fastapi -i https://pypi.douban.com/simple/
#
# gunicorn tmp_app:app -b 0.0.0.0:8000 -w 4 -k uvicorn.workers.UvicornWorker -D
#   (original note: "not supported on linux" -- TODO confirm, possibly meant
#   Windows; -w = number of worker processes)
# uvicorn memo_webhook:app --reload --host 0.0.0.0 --port 9001