from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from qdrant_client.models import PointStruct
from qdrant_client.models import Filter, FieldCondition, MatchValue
from qdrant_client.models import PointIdsList
from text2vec import SentenceModel


class MyVector:
    """Small convenience wrapper around a Qdrant client and a text2vec model.

    The instance keeps a ``QdrantClient`` in ``self.client`` and a
    ``SentenceModel`` in ``self.model``; the methods below forward to the
    client with ``"test_collection"`` as the default collection name.
    """

    def __init__(self, ip, port=6333, grpc_port=6334, api_key=None, embedding_path=None):
        """Connect to Qdrant and load the embedding model.

        Args:
            ip: Host name or IP address of the Qdrant server; ``"http://"``
                is prepended to build the URL.
            port: REST port passed to ``QdrantClient``.
            grpc_port: gRPC port passed to ``QdrantClient``.
            api_key: API key passed to ``QdrantClient`` (may be ``None``).
            embedding_path: Value handed to ``SentenceModel`` (presumably a
                local model directory or model name -- confirm against
                text2vec).
        """
        self.url = "http://" + ip
        self.port = port
        self.grpc_port = grpc_port
        self.api_key = api_key
        self.client = QdrantClient(
            url=self.url,
            port=self.port,
            grpc_port=self.grpc_port,
            api_key=self.api_key,
        )
        self.model = SentenceModel(embedding_path)

    def create_collection(self, collection_name="test_collection", size=16):
        """Create a collection that stores vectors of length ``size``.

        Args:
            collection_name: Name of the collection to create.
            size: Dimensionality of the vectors stored in the collection.
        """
        self.client.create_collection(
            collection_name=collection_name,
            # ``size`` is the vector dimensionality; ``distance`` selects the
            # similarity metric. Available metrics include COSINE, EUCLID,
            # DOT and MANHATTAN; cosine is used here.
            vectors_config=VectorParams(size=size, distance=Distance.COSINE),
        )

    def delete_collection(self, collection_name="test_collection"):
        """Delete the collection named ``collection_name``."""
        self.client.delete_collection(
            collection_name=collection_name,
        )

    def create_document(self, collection_name="test_collection", doc_id=None, vector=None, payload=None):
        """Insert or overwrite a single point and wait for completion.

        Args:
            collection_name: Target collection.
            doc_id: Point id to store the vector under.
            vector: The embedding. Objects exposing ``tolist()`` (for example
                array-like outputs of the sentence model) are converted to a
                plain list before being sent.
            payload: Optional payload stored alongside the vector.

        Returns:
            The operation info object returned by ``client.upsert`` (it is
            also printed, as before).
        """
        # Array-like vectors are normalised to a list so they can be placed
        # in a PointStruct; lists and named-vector dicts pass through as-is.
        if hasattr(vector, "tolist"):
            vector = vector.tolist()

        operation_info = self.client.upsert(
            collection_name=collection_name,
            wait=True,
            points=[
                PointStruct(id=doc_id, vector=vector, payload=payload),
            ],
        )
        # Prints the operation id and status reported by the server.
        print(operation_info)
        return operation_info

    def delete_documentById(self, collection_name="test_collection", doc_id=None):
        """Delete the point whose id is ``doc_id`` and wait for completion.

        The point is selected by its id (the same id that
        ``create_document`` stores it under), not by a payload field named
        ``"id"``.

        Args:
            collection_name: Collection to delete from.
            doc_id: Id of the point to remove.

        Returns:
            The operation info object returned by ``client.delete`` (it is
            also printed, as before).
        """
        operation_info = self.client.delete(
            collection_name=collection_name,
            wait=True,
            # Select by point id; a payload filter on key "id" would not
            # match the id a point was upserted with.
            points_selector=PointIdsList(points=[doc_id]),
        )
        # Prints the operation id and status reported by the server.
        print(operation_info)
        return operation_info

    def query_data(self, collection_name="test_collection", query_vector=None, limit=10, filter_map=None):
        """Search the collection for the points closest to ``query_vector``.

        Args:
            collection_name: Collection to search.
            query_vector: Vector to compare against the stored vectors.
            limit: Maximum number of results to return.
            filter_map: Optional mapping of payload key -> required value.
                Every entry becomes a ``must`` condition, so a point has to
                match all of them. ``None`` disables filtering.

        Returns:
            Whatever ``client.search`` returns -- per the original example a
            list of scored points such as
            ``[ScoredPoint(id=4, version=0, score=1.362,
            payload={'city': 'New York'}, ...), ...]``.
        """
        if filter_map is not None:
            # One exact-match condition per (key, value) pair in filter_map.
            query_filter = Filter(
                must=[
                    FieldCondition(key=key, match=MatchValue(value=value))
                    for key, value in filter_map.items()
                ]
            )
        else:
            query_filter = None

        search_result = self.client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit,
            query_filter=query_filter,
        )
        return search_result


if __name__ == '__main__':
    # NOTE(review): the server address and API key are hard-coded here;
    # consider loading them from the environment or a config file and
    # rotating this key, since it is committed to source.
    client = MyVector("101.43.195.48", port=6333, grpc_port=6334, api_key="tianyunperfect123456",
                      embedding_path='/Users/alvin/IdeaProjects/python-base/vector_test/models--shibing624--text2vec-base-chinese')

    # Step 1 (done below): create the collection with 768-dimensional vectors.
    # Step 2 (not executed here): add data by encoding a text with
    #   client.model.encode(...) and storing it via
    #   client.create_document(collection_name='memo_collection', doc_id=1, vector=...).

    # Reset the collection: drop it and create it again from scratch.
    client.delete_collection('memo_collection')
    client.create_collection('memo_collection', size=768)

    # Step 3: query the collection with an encoded question.
    data = client.query_data(collection_name='memo_collection',
                             query_vector=client.model.encode("你知道小青的女儿吗?"),
                             limit=10)
    print(data)