vector_util.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  import os

  from qdrant_client import QdrantClient
  from qdrant_client.models import Distance, VectorParams
  from qdrant_client.models import PointStruct
  from qdrant_client.models import PointIdsList
  from qdrant_client.models import Filter, FieldCondition, MatchValue
  from text2vec import SentenceModel
  class MyVector:
      """Wrapper bundling a Qdrant client with a text2vec sentence-embedding model.

      ``self.client`` performs collection and point operations against the
      Qdrant server; ``self.model`` is the ``SentenceModel`` callers use to
      turn text into vectors (``self.model.encode(text)``) before storing or
      querying them.
      """

      def __init__(self, ip, port=6333, grpc_port=6334, api_key=None, embedding_path=None):
          """Connect to a Qdrant server and load the sentence-embedding model.

          Args:
              ip: Host name or IP address of the Qdrant server. The client
                  connects over plain ``http://``.
              port: Qdrant REST port.
              grpc_port: Qdrant gRPC port.
              api_key: Optional Qdrant API key.
              embedding_path: Local directory or model name handed to
                  ``SentenceModel``. When ``None``, ``SentenceModel`` is built
                  with its own default model.
          """
          self.url = "http://" + ip
          self.port = port
          self.grpc_port = grpc_port
          self.api_key = api_key
          self.client = QdrantClient(
              url=self.url,
              port=self.port,
              grpc_port=self.grpc_port,
              api_key=self.api_key,
          )
          # SentenceModel declares its own default model name; passing None
          # explicitly would override that default, so only forward a real path.
          if embedding_path is None:
              self.model = SentenceModel()
          else:
              self.model = SentenceModel(embedding_path)

      def create_collection(self, collection_name="test_collection", size=16, distance=None):
          """Create a collection whose points hold vectors of length ``size``.

          Args:
              collection_name: Name of the collection to create.
              size: Vector dimensionality; every vector stored in or queried
                  against this collection must have this length.
              distance: Similarity metric, a ``Distance`` member such as
                  ``Distance.COSINE``, ``Distance.EUCLID`` (Euclidean),
                  ``Distance.DOT`` (dot product) or ``Distance.MANHATTAN``.
                  Defaults to ``Distance.COSINE``.

          Returns:
              Whatever ``QdrantClient.create_collection`` returns.
          """
          if distance is None:
              distance = Distance.COSINE
          return self.client.create_collection(
              collection_name=collection_name,
              vectors_config=VectorParams(size=size, distance=distance),
          )

      def delete_collection(self, collection_name="test_collection"):
          """Drop ``collection_name`` together with all of its points.

          Returns:
              Whatever ``QdrantClient.delete_collection`` returns.
          """
          return self.client.delete_collection(
              collection_name=collection_name,
          )

      def create_document(self, collection_name="test_collection", doc_id=None, vector=None, payload=None):
          """Insert (or overwrite) one point and wait until the write completes.

          Args:
              collection_name: Target collection.
              doc_id: Point id under which the vector is stored.
              vector: The point's vector. Array-like values exposing
                  ``tolist()`` (such as a NumPy array) are converted to a plain
                  Python list first.
              payload: Optional dict of metadata stored alongside the vector.

          Returns:
              The upsert result, e.g.
              ``operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>``.
              It is also printed, as before.
          """
          if hasattr(vector, "tolist"):
              # PointStruct validates the vector as a list of floats; convert
              # array types (and their scalar dtypes) to plain Python floats.
              vector = vector.tolist()
          operation_info = self.client.upsert(
              collection_name=collection_name,
              wait=True,
              points=[
                  PointStruct(id=doc_id, vector=vector, payload=payload),
              ],
          )
          print(operation_info)
          return operation_info

      def delete_documentById(self, collection_name="test_collection", doc_id=None):
          """Delete the point whose id is ``doc_id`` and wait until it completes.

          The point is selected by its point id (the ``doc_id`` passed to
          ``create_document``), not by a payload field. ``create_document``
          never writes an ``"id"`` payload key, so selecting by point id is
          what makes this method find the document.

          Returns:
              The delete result, e.g.
              ``operation_id=0 status=<UpdateStatus.COMPLETED: 'completed'>``.
              It is also printed, as before.
          """
          operation_info = self.client.delete(
              collection_name=collection_name,
              wait=True,
              points_selector=PointIdsList(points=[doc_id]),
          )
          print(operation_info)
          return operation_info

      def query_data(self, collection_name="test_collection", query_vector=None, limit=10, filter_map=None):
          """Return up to ``limit`` points most similar to ``query_vector``.

          Args:
              collection_name: Collection to search.
              query_vector: Vector to compare against stored points.
              limit: Maximum number of results.
              filter_map: Optional ``{payload_key: value}`` mapping. Only points
                  whose payload exactly matches every pair are considered.

          Returns:
              The search result list, e.g.
              ``[ScoredPoint(id=4, version=0, score=1.362,
              payload={'city': 'New York'}, vector=None, shard_key=None), ...]``.
          """
          query_filter = None
          if filter_map:
              # Every key/value pair must match ("must" = logical AND).
              query_filter = Filter(
                  must=[
                      FieldCondition(key=key, match=MatchValue(value=value))
                      for key, value in filter_map.items()
                  ]
              )
          return self.client.search(
              collection_name=collection_name,
              query_vector=query_vector,
              limit=limit,
              query_filter=query_filter,
          )
  if __name__ == '__main__':
      # Demo: (re)create a collection, store one document, then query it.
      # Connection settings come from the environment so that credentials are
      # never committed to source control; QDRANT_API_KEY has no default.
      client = MyVector(
          os.environ.get("QDRANT_HOST", "101.43.195.48"),
          port=int(os.environ.get("QDRANT_PORT", "6333")),
          grpc_port=int(os.environ.get("QDRANT_GRPC_PORT", "6334")),
          api_key=os.environ.get("QDRANT_API_KEY"),
          embedding_path=os.environ.get(
              "EMBEDDING_PATH",
              '/Users/alvin/IdeaProjects/python-base/vector_test/models--shibing624--text2vec-base-chinese',
          ),
      )
      collection = 'memo_collection'

      # 1. (Re)create the collection. WARNING: this drops any existing data.
      #    size=768 is assumed to match the embedding model's output length.
      client.delete_collection(collection)
      client.create_collection(collection, size=768)

      # 2. Insert a sample document; without it the query below would always
      #    run against an empty collection and return nothing.
      sample_text = "小青的女儿今年上小学一年级。"
      client.create_document(
          collection_name=collection,
          doc_id=1,
          vector=client.model.encode(sample_text),
          payload={"text": sample_text},
      )

      # 3. Query by semantic similarity.
      data = client.query_data(
          collection_name=collection,
          query_vector=client.model.encode("你知道小青的女儿吗?"),
          limit=10,
      )
      print(data)