123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248 |
- import json
- import logging
- import uvicorn as uvicorn
- from fastapi import FastAPI, Request
- from pydantic import BaseModel
- from starlette.middleware.cors import CORSMiddleware
- from starlette.middleware.sessions import SessionMiddleware
- from starlette.responses import JSONResponse
- import pymysql
- from dbutils.pooled_db import PooledDB
class MySQLPool:
    """Small helper around a DBUtils ``PooledDB`` pool of pymysql connections.

    Each query method borrows a connection from the pool, runs a single
    statement, commits, and closes the borrowed connection again.
    """

    def __init__(self, host, port, user, password, database, pool_size=5):
        """Create the underlying pool.

        Args:
            host: MySQL server host.
            port: MySQL server port.
            user: Login user name.
            password: Login password.
            database: Schema selected for every connection.
            pool_size: Passed to ``PooledDB`` as ``maxconnections``.
        """
        pool_settings = {
            "creator": pymysql,  # DB-API module used to open connections
            "maxconnections": pool_size,  # maximum number of pooled connections
            "host": host,
            "port": port,
            "user": user,
            "password": password,
            "database": database,
            "charset": 'utf8mb4',
            "cursorclass": pymysql.cursors.DictCursor,
        }
        self.pool = PooledDB(**pool_settings)

    def get_connection(self):
        """Borrow a connection from the pool."""
        return self.pool.connection()

    def select(self, sql, params=None):
        """Execute ``sql`` with ``params`` and return every fetched row.

        The connection is committed before the rows are fetched and is
        always closed afterwards, even when the statement raises.
        """
        connection = self.get_connection()
        try:
            with connection.cursor() as cur:
                cur.execute(sql, params)
                connection.commit()
                rows = cur.fetchall()
            return rows
        finally:
            connection.close()

    def exec(self, sql, params=None):
        """Execute ``sql`` with ``params``, commit, and return ``cursor.execute``'s result.

        The connection is always closed afterwards, even when the
        statement raises.
        """
        connection = self.get_connection()
        try:
            with connection.cursor() as cur:
                outcome = cur.execute(sql, params)
                connection.commit()
            return outcome
        finally:
            connection.close()
# Shared connection pool for the memos database.
# Every setting can be overridden through a MEMOS_DB_* environment variable;
# the literals below are only the historical fallbacks, so behavior is
# unchanged when no variable is set.
# NOTE(review): a real password is committed here -- rotate it and drop the
# fallback once the environment variables are set in every deployment.
import os

db_memos = MySQLPool(
    os.environ.get("MEMOS_DB_HOST", "101.43.195.48"),
    int(os.environ.get("MEMOS_DB_PORT", "3306")),
    os.environ.get("MEMOS_DB_USER", "memos"),
    os.environ.get("MEMOS_DB_PASSWORD", "s3ijTsaH5cciP8s8"),
    os.environ.get("MEMOS_DB_NAME", "memos"),
)
- from qdrant_client import QdrantClient
- from qdrant_client.models import Distance, VectorParams
- from qdrant_client.models import PointStruct
- from qdrant_client.models import Filter, FieldCondition, MatchValue
- from text2vec import SentenceModel
class MyVector:
    """Qdrant client bundled with a text2vec sentence-embedding model.

    Attributes:
        client: The ``QdrantClient`` used for every collection/point call.
        model: The ``SentenceModel`` loaded from ``embedding_path``.
    """

    def __init__(self, ip, port=6333, grpc_port=6334, api_key=None, embedding_path=None):
        """Connect to Qdrant at ``http://<ip>`` and load the embedding model.

        Args:
            ip: Host of the Qdrant server (``http://`` is prepended).
            port: REST port handed to ``QdrantClient``.
            grpc_port: gRPC port handed to ``QdrantClient``.
            api_key: API key handed to ``QdrantClient``.
            embedding_path: Argument passed to ``SentenceModel``.
        """
        self.url = "http://" + ip
        self.port = port
        self.grpc_port = grpc_port
        self.api_key = api_key
        self.client = QdrantClient(
            url=self.url,
            port=self.port,
            grpc_port=self.grpc_port,
            api_key=self.api_key
        )
        self.model = SentenceModel(embedding_path)

    def create_collection(self, collection_name="test_collection", size=16):
        """Create a collection of ``size``-dimensional vectors compared by cosine distance."""
        self.client.create_collection(
            # Name of the collection (index)
            collection_name=collection_name,
            # Length of the vectors stored in the collection.
            # ``size`` is the dimensionality of the data.
            # ``distance`` is the metric: COSINE, EUCLID (Euclidean),
            # DOT (dot product) or MANHATTAN.
            vectors_config=VectorParams(size=size, distance=Distance.COSINE),
        )

    def delete_collection(self, collection_name="test_collection"):
        """Drop the named collection."""
        self.client.delete_collection(
            collection_name=collection_name,
        )

    def create_document(self, collection_name="test_collection", doc_id=None, vector=None, payload=None):
        """Upsert one point (``doc_id``, ``vector``, ``payload``) and print the result.

        Raises:
            ValueError: If ``vector`` is not a list, or contains a
                non-``float`` element.
        """
        if not isinstance(vector, list):
            raise ValueError("Vector must be a list")
        if not all(isinstance(item, float) for item in vector):
            raise ValueError("Vector must be a list of floats")
        operation_info = self.client.upsert(
            collection_name=collection_name,
            wait=True,
            points=[
                PointStruct(id=doc_id, vector=vector, payload=payload),
            ],
        )
        # Example result:
        # operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>
        print(operation_info)

    def delete_documentById(self, collection_name="test_collection", doc_id=None):
        """Delete the point whose point id equals ``doc_id`` and print the result.

        Raises:
            ValueError: If ``doc_id`` is ``None``.
        """
        if doc_id is None:
            raise ValueError("doc_id is required")
        # create_document() stores doc_id as the *point id*, not as a payload
        # field, so the previous payload filter on key "id" never matched
        # anything. Select the point by its id instead.
        operation_info = self.client.delete(
            collection_name=collection_name,
            wait=True,
            points_selector=[doc_id],
        )
        # Example result:
        # operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>
        print(operation_info)

    def query_data(self, collection_name="test_collection", query_vector=None, limit=10, filter_map=None):
        """Search ``collection_name`` for the points nearest to ``query_vector``.

        Args:
            collection_name: Collection to search.
            query_vector: Vector to compare against.
            limit: Maximum number of hits requested.
            filter_map: Optional ``{payload_key: value}`` mapping; every
                pair becomes a ``must`` match condition.

        Returns:
            Whatever ``QdrantClient.search`` returns (a list of scored points).
        """
        if filter_map is not None:
            # One exact-match condition per key/value pair in filter_map
            query_filter = Filter(
                must=[
                    FieldCondition(
                        key=key,
                        match=MatchValue(value=value)
                    )
                    for key, value in filter_map.items()
                ]
            )
        else:
            query_filter = None
        search_result = self.client.search(
            # Collection to search
            collection_name=collection_name,
            # Query vector
            query_vector=query_vector,
            # Cap on the number of results
            limit=limit,
            query_filter=query_filter,
        )
        # Example result:
        # [ScoredPoint(id=4, version=0, score=1.362, payload={'city': 'New York'}, vector=None, shard_key=None), ...]
        return search_result
# Shared vector-store client (Qdrant connection + embedding model).
# Every setting can be overridden through an environment variable; the
# literals below are only the historical fallbacks, so behavior is unchanged
# when no variable is set.
# NOTE(review): a real API key is committed here -- rotate it and drop the
# fallback once the environment variables are set in every deployment.
import os

client = MyVector(
    os.environ.get("QDRANT_HOST", "101.43.195.48"),
    port=int(os.environ.get("QDRANT_PORT", "6333")),
    grpc_port=int(os.environ.get("QDRANT_GRPC_PORT", "6334")),
    api_key=os.environ.get("QDRANT_API_KEY", "tianyunperfect123456"),
    embedding_path=os.environ.get(
        "EMBEDDING_MODEL_PATH",
        '/app/memos/models--shibing624--text2vec-base-chinese',
    ),
)
def getMemoById(id):
    """Return the rows of table ``memo`` whose ``id`` column equals ``id``.

    The value is bound as a query parameter, and the result is whatever
    ``db_memos.select`` returns for the statement.
    """
    query = "select * from memo where id = %s"
    bound_values = (id,)
    return db_memos.select(query, bound_values)
# FastAPI application plus session and CORS middleware.
import os

app = FastAPI(title="memo_webhook name ", description="系统描述 ", version="v 0.0.0")

# The session signing key can be supplied through SESSION_SECRET_KEY; the
# literal is only the historical fallback.
# NOTE(review): the fallback key is committed to source -- set the environment
# variable in every deployment and then remove the fallback.
app.add_middleware(
    SessionMiddleware,
    secret_key=os.environ.get("SESSION_SECRET_KEY", '123456hhh'),
)

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*", ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.exception_handler(Exception)
def validation_exception_handler(request, exc):
    """Handler registered for ``Exception``: reply with a generic JSON error body.

    Neither ``request`` nor ``exc`` is inspected. No explicit HTTP status is
    passed to ``JSONResponse``; the value 500 appears only inside the body.
    """
    error_body = {'message': "服务器内部错误", 'status_code': 500}
    return JSONResponse(error_body)
@app.get('/test')
async def test():
    """Return a fixed greeting payload for ``GET /test``."""
    greeting = {'msg': 'hello world'}
    return greeting
class Item(BaseModel):
    """Pydantic model with three fields: ``creator_id``, ``created_ts`` and ``memo``.

    NOTE(review): nothing in the visible part of this file references this
    model -- confirm whether it is still needed.
    """

    # Integer creator identifier.
    creator_id: int
    # Integer creation timestamp (unit not shown here -- TODO confirm).
    created_ts: int
    # Arbitrary memo object; typed as plain ``object``.
    memo: object
@app.post('/post')
async def test123(request: Request):
    """Webhook receiver that mirrors memo changes into ``memo_collection``.

    Reads ``creatorId``, ``activityType`` and ``memo`` (``id``, ``content``)
    from the JSON body. Events whose ``creatorId`` is not 1 are ignored.
    The memo is removed from the vector index when the activity type is
    ``memos.memo.deleted``, when its database row cannot be found, or when
    the row's ``row_status`` is not ``NORMAL``. Otherwise its content is
    embedded and upserted together with the row's ``resource_name``.

    Returns:
        Always ``{'msg': 'hello world'}``.
    """
    body = await request.json()
    print(body)
    # Only events created by user 1 are indexed
    if body['creatorId'] != 1:
        return {'msg': 'hello world'}

    activity_type = body['activityType']
    content_ = body['memo']["content"]
    id_ = body['memo']["id"]

    memo_row = None
    if activity_type != 'memos.memo.deleted':
        rows = getMemoById(id_)
        # The row can be missing (e.g. already purged). Indexing [0] on an
        # empty result used to raise IndexError; treat that case as a delete.
        memo_row = rows[0] if rows else None
        print(memo_row)

    if memo_row is None or memo_row['row_status'] != 'NORMAL':
        print("delete")
        client.delete_documentById(collection_name="memo_collection", doc_id=id_)
    else:
        print("upsert")
        # create_document() insists on built-in floats, so convert each
        # component of the encoded vector explicitly.
        vector = [float(component) for component in client.model.encode(content_)]
        client.create_document(
            collection_name="memo_collection",
            doc_id=id_,
            vector=vector,
            payload={"content": content_, 'resource_name': memo_row['resource_name']},
        )
    return {'msg': 'hello world'}
@app.get('/query')
async def query_data(request: Request):
    """Embed the ``q`` query parameter and search ``memo_collection`` for it.

    Returns the result of ``client.query_data`` with a limit of 10 hits.
    """
    query_text = request.query_params['q']
    print(query_text)
    # Convert every component of the encoded vector to a built-in float
    query_vector = [float(component) for component in client.model.encode(query_text)]
    return client.query_data(
        collection_name="memo_collection",
        query_vector=query_vector,
        limit=10,
    )
if __name__ == '__main__':
    # Development entry point: serve memo_webhook:app on 127.0.0.1:8000
    # with reload enabled.
    uvicorn.run(
        app='memo_webhook:app',
        host="127.0.0.1",
        port=8000,
        reload=True,
    )
# Startup commands
# pip3 install -r /tmp/request.txt -i https://pypi.douban.com/simple/
# pip3 install uvicorn gunicorn fastapi -i https://pypi.douban.com/simple/
#
# gunicorn tmp_app:app -b 0.0.0.0:8000 -w 4 -k uvicorn.workers.UvicornWorker -D # original note said "cannot run on linux" (NOTE(review): gunicorn is Unix-only, so this probably meant Windows); -w is the number of worker processes
# uvicorn memo_webhook:app --reload --host 0.0.0.0 --port 9001
|