Mercurial > code > home > repos > sco-bot
comparison query.py @ 2:82428652cda1
rewrite
author | drewp@bigasterisk.com |
---|---|
date | Wed, 03 Jul 2024 20:20:08 -0700 |
parents | ca5da75f03ee |
children |
comparison
equal
deleted
inserted
replaced
1:c2176e9a2696 | 2:82428652cda1 |
---|---|
1 import json | |
2 import sys | |
1 from pathlib import Path | 3 from pathlib import Path |
2 from pprint import pprint | 4 |
3 import re | 5 from tqdm import tqdm |
4 import sys | 6 |
7 from pymilvus import MilvusClient | |
8 from milvus_model.dense.onnx import OnnxEmbeddingFunction | |
9 | |
5 from extract_pdf import phrasesFromFile | 10 from extract_pdf import phrasesFromFile |
6 from pymilvus import model | 11 |
7 from pymilvus import MilvusClient | 12 |
13 def rebuild(client, embedding_fn, dim): | |
14 client.drop_collection(collection_name="demo_collection") | |
15 if not client.has_collection(collection_name="demo_collection"): | |
16 client.create_collection( | |
17 collection_name="demo_collection", | |
18 dimension=dim, | |
19 ) | |
20 | |
21 docs = [] | |
22 for i, (bbox, phrase) in tqdm(enumerate( | |
23 phrasesFromFile( | |
24 Path("data") / | |
25 "Meetings2226Minutes_20240702182359526 (1).pdf"))): | |
26 [vector] = embedding_fn.encode_documents([phrase]) | |
27 doc = { | |
28 "id": i, | |
29 "vector": vector, | |
30 "text": phrase, | |
31 "bbox": json.dumps(bbox), | |
32 } | |
33 docs.append(doc) | |
34 res = client.insert(collection_name="demo_collection", data=docs) | |
35 print('insert:', res) | |
36 | |
37 | |
38 def search(q, embedding_fn, client): | |
39 query_vectors = embedding_fn.encode_queries([q]) | |
40 | |
41 [query_result] = client.search( | |
42 collection_name="demo_collection", | |
43 data=query_vectors, | |
44 limit=5, | |
45 output_fields=["text"], | |
46 ) | |
47 query_result.sort(key=lambda x: x["distance"], reverse=True) | |
48 | |
49 for row in query_result: | |
50 print(f'{row["distance"]:.6f} {row["entity"]["text"]}') | |
51 | |
8 | 52 |
9 q, = sys.argv[1:] | 53 q, = sys.argv[1:] |
10 | 54 |
11 def cleanup(phrase: str) -> str: | 55 embedding_fn = OnnxEmbeddingFunction(model_name="GPTCache/paraphrase-albert-onnx") |
12 p = phrase.replace('\n', ' ') | |
13 p = re.sub(r'\s+', ' ', p) | |
14 if len(p) < 5: | |
15 return '' | |
16 return p | |
17 | |
18 | |
19 embedding_fn = model.DefaultEmbeddingFunction() | |
20 | |
21 client = MilvusClient("milvus_demo.db") | 56 client = MilvusClient("milvus_demo.db") |
22 | 57 rebuild(client, embedding_fn, dim=embedding_fn.dim) |
23 # client.drop_collection(collection_name="demo_collection") | 58 search(q, embedding_fn, client) |
24 # if not client.has_collection(collection_name="demo_collection"): | |
25 # client.create_collection( | |
26 # collection_name="demo_collection", | |
27 # dimension=768, # The vectors we will use in this demo has 768 dimensions | |
28 # ) | |
29 | |
30 # docs = [] | |
31 # for i, (bbox, phrase) in enumerate(phrasesFromFile(Path("data") / "Meetings2226Minutes_20240702182359526 (1).pdf")): | |
32 # phrase = cleanup(phrase) | |
33 # print(f'{phrase=}') | |
34 # if not phrase: | |
35 # continue | |
36 | |
37 # [vector] = embedding_fn.encode_documents([phrase]) | |
38 # doc = { | |
39 | |
40 # "id": i, | |
41 # "vector": vector, | |
42 # "text": phrase, | |
43 # } | |
44 # docs.append(doc) | |
45 # res = client.insert(collection_name="demo_collection", data=docs) | |
46 # print('insert:', res) | |
47 | |
48 query_vectors = embedding_fn.encode_queries([q]) | |
49 | |
50 [query_result] = client.search( | |
51 collection_name="demo_collection", | |
52 data=query_vectors, | |
53 limit=15, | |
54 output_fields=["text"], | |
55 ) | |
56 | |
57 for row in query_result: | |
58 print(f'{row["distance"]:.6f} {row["entity"]["text"]}') | |
59 # import ipdb; ipdb.set_trace() |