tianyun 9 months ago
parent
commit
7fca5993d1
12 changed files with 536 additions and 1 deletions
  1. 2 0
      .gitignore
  2. 39 0
      tmp10.py
  3. 15 0
      tmp7.py
  4. 13 0
      tmp8.py
  5. 23 0
      tmp9.py
  6. 18 1
      tmp_app.py
  7. 33 0
      vector_test/init-db.py
  8. 248 0
      vector_test/memo_webhook.py
  9. 7 0
      vector_test/tmp01.py
  10. 16 0
      vector_test/tmp02.py
  11. 7 0
      vector_test/tmp03.py
  12. 115 0
      vector_test/vector_util.py

+ 2 - 0
.gitignore

@@ -4,3 +4,5 @@ __pycache__
 .DS_Store
 .DS_Store
 /catboost_info/
 /catboost_info/
 *.csv
 *.csv
+/vector_test/models--shibing624--text2vec-base-chinese/
+/copyD/nohup.out

+ 39 - 0
tmp10.py

@@ -0,0 +1,39 @@
+from qdrant_client import QdrantClient
+
+from qdrant_client.models import Distance, VectorParams
+from qdrant_client.models import PointStruct
+
+client = QdrantClient(path="my_qdrant")
+
+
+def create_collection():
+    client.create_collection(
+        collection_name="test_collection",
+        vectors_config=VectorParams(size=4, distance=Distance.DOT),
+    )
+
+
+def add_vectors():
+    operation_info = client.upsert(
+        collection_name="test_collection",
+        wait=True,
+        points=[
+            PointStruct(id=1, vector=[0.05, 0.61, 0.76, 0.74], payload={"city": "Berlin"}),
+            PointStruct(id=2, vector=[0.19, 0.81, 0.75, 0.11], payload={"city": "London"}),
+            PointStruct(id=3, vector=[0.36, 0.55, 0.47, 0.94], payload={"city": "Moscow"}),
+            PointStruct(id=4, vector=[0.18, 0.01, 0.85, 0.80], payload={"city": "New York"}),
+            PointStruct(id=5, vector=[0.24, 0.18, 0.22, 0.44], payload={"city": "Beijing"}),
+            PointStruct(id=6, vector=[0.35, 0.08, 0.11, 0.44], payload={"city": "Mumbai"}),
+        ],
+    )
+
+    print(operation_info)
+
+
+def query():
+    search_result = client.search(
+        collection_name="test_collection", query_vector=[0.2, 0.1, 0.9, 0.7], limit=3, with_vectors=True
+    )
+    print(search_result)
+
+# create_collection()

+ 15 - 0
tmp7.py

@@ -0,0 +1,15 @@
+import random
+
+from visualdl import LogWriter
+
+logdir = "./log/scalar_test/train"
+if __name__ == '__main__':
+    # 随机一个小数
+    value = [i * random.random() / 1000.0 for i in range(1000)]
+    # 初始化一个记录器
+    with LogWriter(logdir=logdir) as writer:
+        for step in range(1000):
+            # 向记录器添加一个tag为`acc`的数据
+            writer.add_scalar(tag="acc", step=step, value=value[step])
+            # 向记录器添加一个tag为`loss`的数据
+            writer.add_scalar(tag="loss", step=step, value=1 / (value[step] + 1))

+ 13 - 0
tmp8.py

@@ -0,0 +1,13 @@
+from visualdl.server import app
+
+logdir = "./log/scalar_test/train"
+
+if __name__ == '__main__':
+    app.run(logdir=logdir,
+            host="127.0.0.1",
+            port=8080,
+            cache_timeout=20,
+            language=None,
+            public_path=None,
+            api_only=False,
+            open_browser=False)

+ 23 - 0
tmp9.py

@@ -0,0 +1,23 @@
+import requests
+
+headers = {
+    'accept': '*/*',
+    'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8',
+    'authorization': 'bearer eyJhbGciOiJIUzI1NiIsImtpZCI6InYxIiwidHlwIjoiSldUIn0.eyJuYW1lIjoidGlhbnl1bnBlcmZlY3QiLCJpc3MiOiJtZW1vcyIsInN1YiI6IjEiLCJhdWQiOlsidXNlci5hY2Nlc3MtdG9rZW4iXSwiaWF0IjoxNzA5MTc5NTUyfQ.LFxWB4efya1sL7VoJ42xpXxbAip-udT_Kx2OwZ8Y3-E',
+    'cache-control': 'no-cache',
+    'content-type': 'application/json',
+    'origin': 'https://web.tianyunperfect.cn',
+    'pragma': 'no-cache',
+    'priority': 'u=1, i',
+    'referer': 'https://web.tianyunperfect.cn/simple/memos.html',
+    'sec-ch-ua': '"Chromium";v="128", "Not;A=Brand";v="24", "Google Chrome";v="128"',
+    'sec-ch-ua-mobile': '?0',
+    'sec-ch-ua-platform': '"macOS"',
+    'sec-fetch-dest': 'empty',
+    'sec-fetch-mode': 'cors',
+    'sec-fetch-site': 'same-site',
+    'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36',
+}
+
+response = requests.get('https://memos.tianyunperfect.cn/api/v1/tag', headers=headers, verify=False)
+print(response)

+ 18 - 1
tmp_app.py

@@ -1,5 +1,5 @@
 import uvicorn as uvicorn
 import uvicorn as uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from pydantic.main import BaseModel
 from pydantic.main import BaseModel
 from starlette.middleware.cors import CORSMiddleware
 from starlette.middleware.cors import CORSMiddleware
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.middleware.sessions import SessionMiddleware
@@ -22,6 +22,22 @@ def validation_exception_handler(request, exc):
     return JSONResponse({'message': "服务器内部错误", 'status_code': 500})
     return JSONResponse({'message': "服务器内部错误", 'status_code': 500})
 
 
 
 
+@app.post('/post')
+async def test123(request: Request):
+    # 获取所有参数
+    params = dict(request.query_params)
+    print(f"Query parameters: {params}")
+
+    # 获取请求头
+    headers = dict(request.headers)
+    print(f"Headers: {headers}")
+
+    # 获取请求体
+    body = await request.json()  # 如果是post form,请使用request.form()
+    print(f"Body: {body}")
+    return {'msg': 'hello world'}
+
+
 @app.get("/query")
 @app.get("/query")
 def query(uid: str):
 def query(uid: str):
     msg = f'uid为{uid}'
     msg = f'uid为{uid}'
@@ -33,6 +49,7 @@ def query(uid: str):
     msg = f'uid为{uid}'
     msg = f'uid为{uid}'
     return {'success': True, 'msg': msg}
     return {'success': True, 'msg': msg}
 
 
+
 @app.get("/test123")
 @app.get("/test123")
 def test123():
 def test123():
     # 休眠20秒
     # 休眠20秒

+ 33 - 0
vector_test/init-db.py

@@ -0,0 +1,33 @@
+import pymysql
+
+from vector_test.vector_util import MyVector
+
+
+class Database:
+    def __init__(self, host, user, password, database):
+        self.connection = pymysql.connect(
+            host=host,
+            user=user,
+            password=password,
+            database=database
+        )
+
+    def query(self, sql):
+        with self.connection.cursor() as cursor:
+            cursor.execute(sql)
+            result = cursor.fetchall()
+        return result
+
+    def execute(self, sql):
+        with self.connection.cursor() as cursor:
+            cursor.execute(sql)
+        self.connection.commit()
+
+
+db = Database('www.tianyunperfect.cn', 'memos', 's3ijTsaH5cciP8s8', 'memos')
+client = MyVector("101.43.195.48", port=6333, grpc_port=6334, api_key="tianyunperfect123456", embedding_path='/Users/alvin/IdeaProjects/python-base/vector_test/models--shibing624--text2vec-base-chinese')
+# 查询
+result = db.query("SELECT id,content,resource_name FROM memo where row_status = 'NORMAL' and creator_id=1")
+# 遍历result
+for row in result:
+    client.create_document(collection_name="memo_collection", doc_id=row[0], vector=client.model.encode(row[1]), payload={"content": row[1], "resource_name": row[2]})

+ 248 - 0
vector_test/memo_webhook.py

@@ -0,0 +1,248 @@
+import json
+import logging
+
+import uvicorn as uvicorn
+from fastapi import FastAPI, Request
+from pydantic import BaseModel
+from starlette.middleware.cors import CORSMiddleware
+from starlette.middleware.sessions import SessionMiddleware
+from starlette.responses import JSONResponse
+import pymysql
+from dbutils.pooled_db import PooledDB
+
+
+class MySQLPool:
+    def __init__(self, host, port, user, password, database, pool_size=5):
+        self.pool = PooledDB(
+            creator=pymysql,  # 指定 MySQL 的 Python 客户端库
+            maxconnections=pool_size,  # 连接池中最多存在的连接数
+            host=host,
+            port=port,
+            user=user,
+            password=password,
+            database=database,
+            charset='utf8mb4',
+            cursorclass=pymysql.cursors.DictCursor,
+        )
+
+    def get_connection(self):
+        return self.pool.connection()
+
+    def select(self, sql, params=None):
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cursor:
+                cursor.execute(sql, params)
+                conn.commit()
+                return cursor.fetchall()
+        finally:
+            conn.close()
+
+    def exec(self, sql, params=None):
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cursor:
+                result = cursor.execute(sql, params)
+                conn.commit()
+                return result
+        finally:
+            conn.close()
+
+
+db_memos = MySQLPool("101.43.195.48", 3306, "memos", "s3ijTsaH5cciP8s8", "memos")
+
+from qdrant_client import QdrantClient
+from qdrant_client.models import Distance, VectorParams
+
+from qdrant_client.models import PointStruct
+
+from qdrant_client.models import Filter, FieldCondition, MatchValue
+from text2vec import SentenceModel
+
+
+class MyVector:
+    def __init__(self, ip, port=6333, grpc_port=6334, api_key=None, embedding_path=None):
+        self.url = "http://" + ip
+        self.port = port
+        self.grpc_port = grpc_port
+        self.api_key = api_key
+        self.client = QdrantClient(
+            url=self.url,
+            port=self.port,
+            grpc_port=self.grpc_port,
+            api_key=self.api_key
+        )
+        self.model = SentenceModel(embedding_path)
+
+    def create_collection(self, collection_name="test_collection", size=16):
+        self.client.create_collection(
+            # 设置索引的名称
+            collection_name=collection_name,
+            # 设置索引中输入向量的长度
+            # 参数size是数据维度
+            # 参数distance是计算的方法,主要有COSINE(余弦),EUCLID(欧氏距离)、DOT(点积),MANHATTAN(曼哈顿距离)
+            vectors_config=VectorParams(size=size, distance=Distance.COSINE),
+        )
+
+    def delete_collection(self, collection_name="test_collection"):
+        self.client.delete_collection(
+            collection_name=collection_name,
+        )
+
+    def create_document(self, collection_name="test_collection", doc_id=None, vector=None, payload=None):
+        if not isinstance(vector, list):
+            raise ValueError("Vector must be a list")
+        if not all(isinstance(item, float) for item in vector):
+            raise ValueError("Vector must be a list of floats")
+        operation_info = self.client.upsert(
+            collection_name=collection_name,
+            wait=True,
+            points=[
+                PointStruct(id=doc_id, vector=vector, payload=payload),
+            ],
+        )
+        # 返回值
+        # operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>
+        print(operation_info)
+
+    def delete_documentById(self, collection_name="test_collection", doc_id=None):
+        operation_info = self.client.delete(
+            collection_name=collection_name,
+            wait=True,
+            points_selector=Filter(
+                should=[
+                    FieldCondition(
+                        key="id",
+                        match=MatchValue(value=doc_id)
+                    )
+                ]
+            ),
+        )
+        # 返回值
+        # operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>
+        print(operation_info)
+
+    def query_data(self, collection_name="test_collection", query_vector=None, limit=10, filter_map=None):
+        if filter_map is not None:
+            # 遍历 filter_map 的 key和value
+            query_filter = Filter(
+                must=[
+                    FieldCondition(
+                        key=key,
+                        match=MatchValue(value=value)
+                    )
+                    for key, value in filter_map.items()
+                ]
+            )
+        else:
+            query_filter = None
+
+        search_result = self.client.search(
+            # 设置索引
+            collection_name=collection_name,
+            # 查询向量
+            query_vector=query_vector,
+            # 限制返回值的数量
+            limit=limit,
+            query_filter=query_filter,
+        )
+        # 返回值
+        # [ScoredPoint(id=4, version=0, score=1.362, payload={'city': 'New York'}, vector=None, shard_key=None), ScoredPoint(id=1, version=0, score=1.273, payload={'city': 'Berlin'}, vector=None, shard_key=None), ScoredPoint(id=3, version=0, score=1.208, payload={'city': 'Moscow'}, vector=None, shard_key=None)]
+        return search_result
+
+
+client = MyVector("101.43.195.48", port=6333, grpc_port=6334, api_key="tianyunperfect123456", embedding_path='/app/memos/models--shibing624--text2vec-base-chinese')
+
+
+def getMemoById(id):
+    return db_memos.select("select * from memo where id = %s", (id,))
+
+
+app = FastAPI(title="memo_webhook name ", description="系统描述 ", version="v 0.0.0")
+app.add_middleware(SessionMiddleware, secret_key='123456hhh')
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*", ],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.exception_handler(Exception)
+def validation_exception_handler(request, exc):
+    """请求校验异常捕获; application/json """
+    return JSONResponse({'message': "服务器内部错误", 'status_code': 500})
+
+
+@app.get('/test')
+async def test():
+    return {'msg': 'hello world'}
+
+
+class Item(BaseModel):
+    creator_id: int
+    created_ts: int
+    memo: object
+
+
+@app.post('/post')
+async def test123(request: Request):
+    body = await request.json()
+    print(body)
+    creator_id = body['creatorId']
+    # 如果creatorId 不是1
+    if creator_id != 1:
+        return {'msg': 'hello world'}
+
+    activity_type = body['activityType']
+    is_upsert = False
+    is_delete = False
+    content_ = body['memo']["content"]
+    id_ = body['memo']["id"]
+    resource_name = ''
+    if activity_type == 'memos.memo.deleted':
+        is_delete = True
+    else:
+        by_id = getMemoById(id_)[0]
+        print(by_id)
+        if by_id['row_status'] != 'NORMAL':
+            is_delete = True
+        else:
+            is_upsert = True
+            resource_name = by_id['resource_name']
+    if is_delete:
+        print("delete")
+        client.delete_documentById(collection_name="memo_collection", doc_id=id_)
+    elif is_upsert:
+        print("upsert")
+        encode = client.model.encode(content_)
+        # encode转list
+        encode = list(encode)
+        # 内部全部转为float
+        encode = [float(i) for i in encode]
+        client.create_document(collection_name="memo_collection", doc_id=id_, vector=encode,
+                               payload={"content": content_, 'resource_name': resource_name})
+
+    return {'msg': 'hello world'}
+
+
+@app.get('/query')
+async def query_data(request: Request):
+    q_ = request.query_params['q']
+    print(q_)
+    model_encode = client.model.encode(q_)
+    encode = list(model_encode)
+    encode = [float(i) for i in encode]
+    return client.query_data(collection_name="memo_collection", query_vector=encode, limit=10)
+
+
+if __name__ == '__main__':
+    uvicorn.run(app='memo_webhook:app', host="127.0.0.1", port=8000, reload=True)
+
+# 启动命令
+# pip3 install -r /tmp/request.txt -i https://pypi.douban.com/simple/
+# pip3 install uvicorn gunicorn fastapi -i https://pypi.douban.com/simple/
+#
+# gunicorn tmp_app:app -b 0.0.0.0:8000 -w 4 -k uvicorn.workers.UvicornWorker -D  # 不支持linux运行 -w 进程数
+# uvicorn memo_webhook:app --reload --host 0.0.0.0 --port 9001

+ 7 - 0
vector_test/tmp01.py

@@ -0,0 +1,7 @@
+import numpy as np
+from text2vec import SentenceModel
+
+model = SentenceModel('models--shibing624--text2vec-base-chinese')
+
+sentence_embeddings = model.encode('今天天气不错')
+print(sentence_embeddings)

+ 16 - 0
vector_test/tmp02.py

@@ -0,0 +1,16 @@
+from text2vec import Similarity, EncoderType
+
+sim_model = Similarity(model_name_or_path='shibing624/text2vec-base-chinese',
+                       encoder_type=EncoderType.CLS)
+
+
+def ai_text(sentence1, sentence2):
+    score = sim_model.get_score(sentence1, sentence2)
+    print("{} \t\t {} \t\t Score: {:.4f}".format(sentence1, sentence2, score))
+
+    return score
+
+
+if __name__ == '__main__':
+    ai_text("有战争的梦", "#梦记\n梅林,魔法的世界,一个人打巨狼")
+    ai_text("有战争的梦", "2023年08月04日 战争 逃离\n梦里发生了战争,拿着必要物品,背着迁徙,和金津,人群里还有宇佳。\n我被炸弹炸了好几回,着急赶路,似乎要去山上躲避战乱生存。\n #梦记")

+ 7 - 0
vector_test/tmp03.py

@@ -0,0 +1,7 @@
+# jieba分词
+import jieba
+
+# 今天的天气真不错
+sentence = "今天的的天气真不错,你觉得呢?"
+words = jieba.cut(sentence)
+print(" ".join(words))

+ 115 - 0
vector_test/vector_util.py

@@ -0,0 +1,115 @@
+from qdrant_client import QdrantClient
+from qdrant_client.models import Distance, VectorParams
+
+from qdrant_client.models import PointStruct
+
+from qdrant_client.models import Filter, FieldCondition, MatchValue
+from text2vec import SentenceModel
+
+
+class MyVector:
+    def __init__(self, ip, port=6333, grpc_port=6334, api_key=None, embedding_path=None):
+        self.url = "http://" + ip
+        self.port = port
+        self.grpc_port = grpc_port
+        self.api_key = api_key
+        self.client = QdrantClient(
+            url=self.url,
+            port=self.port,
+            grpc_port=self.grpc_port,
+            api_key=self.api_key
+        )
+        self.model = SentenceModel(embedding_path)
+
+    def create_collection(self, collection_name="test_collection", size=16):
+        self.client.create_collection(
+            # 设置索引的名称
+            collection_name=collection_name,
+            # 设置索引中输入向量的长度
+            # 参数size是数据维度
+            # 参数distance是计算的方法,主要有COSINE(余弦),EUCLID(欧氏距离)、DOT(点积),MANHATTAN(曼哈顿距离)
+            vectors_config=VectorParams(size=size, distance=Distance.COSINE),
+        )
+
+    def delete_collection(self, collection_name="test_collection"):
+        self.client.delete_collection(
+            collection_name=collection_name,
+        )
+
+    def create_document(self, collection_name="test_collection", doc_id=None, vector=None, payload=None):
+        operation_info = self.client.upsert(
+            collection_name=collection_name,
+            wait=True,
+            points=[
+                PointStruct(id=doc_id, vector=vector, payload=payload),
+            ],
+        )
+        # 返回值
+        # operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>
+        print(operation_info)
+
+    def delete_documentById(self, collection_name="test_collection", doc_id=None):
+        operation_info = self.client.delete(
+            collection_name=collection_name,
+            wait=True,
+            points_selector=Filter(
+                should=[
+                    FieldCondition(
+                        key="id",
+                        match=MatchValue(value=doc_id)
+                    )
+                ]
+            ),
+        )
+        # 返回值
+        # operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>
+        print(operation_info)
+
+    def query_data(self, collection_name="test_collection", query_vector=None, limit=10, filter_map=None):
+        if filter_map is not None:
+            # 遍历 filter_map 的 key和value
+            query_filter = Filter(
+                must=[
+                    FieldCondition(
+                        key=key,
+                        match=MatchValue(value=value)
+                    )
+                    for key, value in filter_map.items()
+                ]
+            )
+        else:
+            query_filter = None
+
+        search_result = self.client.search(
+            # 设置索引
+            collection_name=collection_name,
+            # 查询向量
+            query_vector=query_vector,
+            # 限制返回值的数量
+            limit=limit,
+            query_filter=query_filter,
+        )
+        # 返回值
+        # [ScoredPoint(id=4, version=0, score=1.362, payload={'city': 'New York'}, vector=None, shard_key=None), ScoredPoint(id=1, version=0, score=1.273, payload={'city': 'Berlin'}, vector=None, shard_key=None), ScoredPoint(id=3, version=0, score=1.208, payload={'city': 'Moscow'}, vector=None, shard_key=None)]
+        return search_result
+
+
+if __name__ == '__main__':
+    client = MyVector("101.43.195.48", port=6333, grpc_port=6334, api_key="tianyunperfect123456", embedding_path='/Users/alvin/IdeaProjects/python-base/vector_test/models--shibing624--text2vec-base-chinese')
+    # 1 创建索引
+    # client.create_collection('memo_collection', size=768)
+
+    # 2 添加数据
+    # add_data()
+    # encode = client.model.encode("胡辉地址:中国江苏省南京市雨花台区安德门大街23号金地威新B栋8楼")
+    # client.create_document(
+    #     collection_name='memo_collection',
+    #     doc_id=1,
+    #     vector=encode,
+    # )
+    client.delete_collection('memo_collection')
+    client.create_collection('memo_collection', size=768)
+
+    # 3 查询数据
+    data = client.query_data(collection_name='memo_collection', query_vector=client.model.encode("你知道小青的女儿吗?"), limit=10)
+    print(data)