comparison query.py @ 2:82428652cda1

rewrite
author drewp@bigasterisk.com
date Wed, 03 Jul 2024 20:20:08 -0700
parents ca5da75f03ee
children
comparison
equal deleted inserted replaced
1:c2176e9a2696 2:82428652cda1
1 import json
2 import sys
1 from pathlib import Path 3 from pathlib import Path
2 from pprint import pprint 4
3 import re 5 from tqdm import tqdm
4 import sys 6
7 from pymilvus import MilvusClient
8 from milvus_model.dense.onnx import OnnxEmbeddingFunction
9
5 from extract_pdf import phrasesFromFile 10 from extract_pdf import phrasesFromFile
6 from pymilvus import model 11
7 from pymilvus import MilvusClient 12
13 def rebuild(client, embedding_fn, dim):
14 client.drop_collection(collection_name="demo_collection")
15 if not client.has_collection(collection_name="demo_collection"):
16 client.create_collection(
17 collection_name="demo_collection",
18 dimension=dim,
19 )
20
21 docs = []
22 for i, (bbox, phrase) in tqdm(enumerate(
23 phrasesFromFile(
24 Path("data") /
25 "Meetings2226Minutes_20240702182359526 (1).pdf"))):
26 [vector] = embedding_fn.encode_documents([phrase])
27 doc = {
28 "id": i,
29 "vector": vector,
30 "text": phrase,
31 "bbox": json.dumps(bbox),
32 }
33 docs.append(doc)
34 res = client.insert(collection_name="demo_collection", data=docs)
35 print('insert:', res)
36
37
38 def search(q, embedding_fn, client):
39 query_vectors = embedding_fn.encode_queries([q])
40
41 [query_result] = client.search(
42 collection_name="demo_collection",
43 data=query_vectors,
44 limit=5,
45 output_fields=["text"],
46 )
47 query_result.sort(key=lambda x: x["distance"], reverse=True)
48
49 for row in query_result:
50 print(f'{row["distance"]:.6f} {row["entity"]["text"]}')
51
8 52
9 q, = sys.argv[1:] 53 q, = sys.argv[1:]
10 54
11 def cleanup(phrase: str) -> str: 55 embedding_fn = OnnxEmbeddingFunction(model_name="GPTCache/paraphrase-albert-onnx")
12 p = phrase.replace('\n', ' ')
13 p = re.sub(r'\s+', ' ', p)
14 if len(p) < 5:
15 return ''
16 return p
17
18
19 embedding_fn = model.DefaultEmbeddingFunction()
20
21 client = MilvusClient("milvus_demo.db") 56 client = MilvusClient("milvus_demo.db")
22 57 rebuild(client, embedding_fn, dim=embedding_fn.dim)
23 # client.drop_collection(collection_name="demo_collection") 58 search(q, embedding_fn, client)
24 # if not client.has_collection(collection_name="demo_collection"):
25 # client.create_collection(
26 # collection_name="demo_collection",
27 # dimension=768, # The vectors we will use in this demo has 768 dimensions
28 # )
29
30 # docs = []
31 # for i, (bbox, phrase) in enumerate(phrasesFromFile(Path("data") / "Meetings2226Minutes_20240702182359526 (1).pdf")):
32 # phrase = cleanup(phrase)
33 # print(f'{phrase=}')
34 # if not phrase:
35 # continue
36
37 # [vector] = embedding_fn.encode_documents([phrase])
38 # doc = {
39
40 # "id": i,
41 # "vector": vector,
42 # "text": phrase,
43 # }
44 # docs.append(doc)
45 # res = client.insert(collection_name="demo_collection", data=docs)
46 # print('insert:', res)
47
48 query_vectors = embedding_fn.encode_queries([q])
49
50 [query_result] = client.search(
51 collection_name="demo_collection",
52 data=query_vectors,
53 limit=15,
54 output_fields=["text"],
55 )
56
57 for row in query_result:
58 print(f'{row["distance"]:.6f} {row["entity"]["text"]}')
59 # import ipdb; ipdb.set_trace()