123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- from qdrant_client import QdrantClient
- from qdrant_client.models import Distance, VectorParams
- from qdrant_client.models import PointStruct
- from qdrant_client.models import Filter, FieldCondition, MatchValue
- from text2vec import SentenceModel
class MyVector:
    """Convenience wrapper around a Qdrant client and a text2vec embedding model.

    Provides helpers to create/delete collections, upsert single points,
    delete points by ID and run (optionally payload-filtered) similarity
    searches.
    """

    def __init__(self, ip, port=6333, grpc_port=6334, api_key=None, embedding_path=None):
        """Connect to a Qdrant server and load the sentence-embedding model.

        Args:
            ip: Host name or IP address of the Qdrant server. ``"http://"`` is
                prepended unless the value already contains a URL scheme.
            port: REST port of the server.
            grpc_port: gRPC port of the server.
            api_key: Optional API key passed to ``QdrantClient``.
            embedding_path: Model name or local directory for ``SentenceModel``.
                When ``None``, ``SentenceModel`` is constructed with its own
                default model instead of being handed ``None``.
        """
        # Accept both a bare host ("1.2.3.4") and a full URL ("https://...");
        # unconditional prefixing turned the latter into "http://https://...".
        self.url = ip if "://" in ip else "http://" + ip
        self.port = port
        self.grpc_port = grpc_port
        self.api_key = api_key
        self.client = QdrantClient(
            url=self.url,
            port=self.port,
            grpc_port=self.grpc_port,
            api_key=self.api_key,
        )
        # NOTE(review): SentenceModel(None) presumably cannot load a model, so
        # fall back to the library default when no path is given.
        self.model = SentenceModel() if embedding_path is None else SentenceModel(embedding_path)

    def create_collection(self, collection_name="test_collection", size=16, distance=Distance.COSINE):
        """Create a collection that stores one unnamed dense vector per point.

        Args:
            collection_name: Name of the collection to create.
            size: Dimensionality of the vectors stored in the collection.
            distance: Similarity metric. ``Distance`` offers COSINE (cosine),
                EUCLID (Euclidean), DOT (dot product) and MANHATTAN. The
                default is cosine, as before.

        Returns:
            The value returned by ``QdrantClient.create_collection``.
        """
        return self.client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=size, distance=distance),
        )

    def delete_collection(self, collection_name="test_collection"):
        """Delete ``collection_name`` and return the client's result."""
        return self.client.delete_collection(
            collection_name=collection_name,
        )

    @staticmethod
    def _to_list(vector):
        """Return ``vector`` as a plain list when it is an array-like with ``tolist``.

        ``PointStruct`` validates its vector as a list of floats, while
        embedding models commonly return numpy arrays (possibly float32).
        Other values are returned unchanged.
        """
        return vector.tolist() if hasattr(vector, "tolist") else vector

    def create_document(self, collection_name="test_collection", doc_id=None, vector=None, payload=None):
        """Insert, or overwrite, a single point.

        Args:
            collection_name: Target collection.
            doc_id: Point ID for the new point.
            vector: Embedding for the point. numpy arrays are converted to lists.
            payload: Optional dict stored alongside the vector.

        Returns:
            The result of ``upsert``, which is also printed, e.g.
            ``operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>``.
        """
        operation_info = self.client.upsert(
            collection_name=collection_name,
            wait=True,  # wait for the write to be applied before returning
            points=[
                PointStruct(id=doc_id, vector=self._to_list(vector), payload=payload),
            ],
        )
        print(operation_info)
        return operation_info

    def delete_documentById(self, collection_name="test_collection", doc_id=None):
        """Delete the point whose point ID is ``doc_id``.

        Points are selected by point ID, matching how ``create_document``
        stores ``doc_id``. Filtering on a payload field called ``"id"`` only
        matched when the caller's payload happened to contain such a key.

        Args:
            collection_name: Collection to delete from.
            doc_id: Point ID to delete.

        Returns:
            The result of ``delete``, which is also printed, e.g.
            ``operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>``.

        Raises:
            ValueError: If ``doc_id`` is None.
        """
        from qdrant_client.models import PointIdsList

        if doc_id is None:
            raise ValueError("doc_id is required")
        operation_info = self.client.delete(
            collection_name=collection_name,
            wait=True,
            points_selector=PointIdsList(points=[doc_id]),
        )
        print(operation_info)
        return operation_info

    def query_data(self, collection_name="test_collection", query_vector=None, limit=10, filter_map=None,
                   query_text=None):
        """Search ``collection_name`` for the points most similar to a query.

        Args:
            collection_name: Collection to search.
            query_vector: Query embedding.
            limit: Maximum number of hits to return.
            filter_map: Optional ``{payload_key: value}`` mapping. Every pair
                must match exactly (conditions are ANDed).
            query_text: Used only when ``query_vector`` is None. It is
                embedded with ``self.model`` to form the query vector.

        Returns:
            The list of ``ScoredPoint`` returned by ``search``, e.g.
            ``[ScoredPoint(id=4, version=0, score=1.362, payload={'city': 'New York'}, ...), ...]``.

        Raises:
            ValueError: If neither ``query_vector`` nor ``query_text`` is given.
        """
        if query_vector is None:
            if query_text is None:
                raise ValueError("either query_vector or query_text is required")
            query_vector = self.model.encode(query_text)

        query_filter = None
        if filter_map:
            # One exact-match condition per key/value pair; "must" means all apply.
            query_filter = Filter(
                must=[
                    FieldCondition(key=key, match=MatchValue(value=value))
                    for key, value in filter_map.items()
                ]
            )

        return self.client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit,
            query_filter=query_filter,
        )
if __name__ == '__main__':
    import os

    # Connection settings are read from the environment so that credentials
    # are not committed to source control.
    # NOTE(review): an API key literal used to be hard-coded here; it should be
    # considered leaked and rotated.
    client = MyVector(
        os.environ.get("QDRANT_HOST", "101.43.195.48"),
        port=6333,
        grpc_port=6334,
        api_key=os.environ.get("QDRANT_API_KEY"),
        embedding_path=os.environ.get(
            "EMBEDDING_PATH",
            '/Users/alvin/IdeaProjects/python-base/vector_test/models--shibing624--text2vec-base-chinese',
        ),
    )
    collection = 'memo_collection'

    # Step 1: drop and recreate the collection for 768-dimensional vectors.
    client.delete_collection(collection)
    client.create_collection(collection, size=768)

    # Step 2 (optional): insert a document, for example
    #   client.create_document(collection_name=collection, doc_id=1,
    #                          vector=client.model.encode("<text>"))

    # Step 3: query.
    # NOTE(review): the collection was just recreated, so this search returns
    # no hits unless step 2 is enabled.
    data = client.query_data(collection_name=collection, query_vector=client.model.encode("你知道小青的女儿吗?"), limit=10)
    print(data)
|