annotate query.py @ 2:82428652cda1

rewrite
author drewp@bigasterisk.com
date Wed, 03 Jul 2024 20:20:08 -0700
parents ca5da75f03ee
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
2
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
1 import json
0
drewp@bigasterisk.com
parents:
diff changeset
2 import sys
2
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
3 from pathlib import Path
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
4
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
5 from tqdm import tqdm
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
6
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
7 from pymilvus import MilvusClient
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
8 from milvus_model.dense.onnx import OnnxEmbeddingFunction
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
9
0
drewp@bigasterisk.com
parents:
diff changeset
10 from extract_pdf import phrasesFromFile
2
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
11
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
12
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
def rebuild(client, embedding_fn, dim, collection_name="demo_collection",
            pdf_path=None):
    """Recreate the collection from scratch and index every phrase of a PDF.

    The collection is dropped and created again, so rows left over from a
    previous run are discarded. Each phrase is embedded on its own and stored
    together with its text and its JSON-encoded bounding box.

    Args:
        client: Milvus client used to drop, create and fill the collection.
        embedding_fn: Object whose ``encode_documents(list_of_str)`` returns
            one vector per input string.
        dim: Vector dimension the collection is created with.
        collection_name: Collection to rebuild.
        pdf_path: PDF handed unchanged to ``phrasesFromFile``. Defaults to
            the meeting-minutes file under ``data/``.
    """
    if pdf_path is None:
        pdf_path = (Path("data") /
                    "Meetings2226Minutes_20240702182359526 (1).pdf")

    # The drop is unconditional, so the collection never exists at this
    # point and can be created without a has_collection() check.
    client.drop_collection(collection_name=collection_name)
    client.create_collection(
        collection_name=collection_name,
        dimension=dim,
    )

    docs = []
    # tqdm wraps the phrase source itself rather than the enumerate object,
    # which is the form the tqdm documentation recommends.
    for i, (bbox, phrase) in enumerate(tqdm(phrasesFromFile(pdf_path))):
        [vector] = embedding_fn.encode_documents([phrase])
        docs.append({
            "id": i,
            "vector": vector,
            "text": phrase,
            "bbox": json.dumps(bbox),
        })

    # Single bulk insert once every phrase has been embedded.
    res = client.insert(collection_name=collection_name, data=docs)
    print('insert:', res)
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
36
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
37
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
def search(q, embedding_fn, client, collection_name="demo_collection",
           limit=5):
    """Print the stored phrases that best match the query string *q*.

    One line is printed per hit, in descending order of the ``distance``
    value reported by the client: the distance with six decimals, a space,
    then the hit's ``text`` field.

    Args:
        q: Query text.
        embedding_fn: Object whose ``encode_queries(list_of_str)`` returns
            one vector per input string.
        client: Milvus client to search with.
        collection_name: Collection to search.
        limit: Maximum number of hits requested from the client.
    """
    query_vectors = embedding_fn.encode_queries([q])

    # Exactly one query vector is sent, so exactly one hit list must come
    # back; the unpacking fails loudly otherwise.
    [hits] = client.search(
        collection_name=collection_name,
        data=query_vectors,
        limit=limit,
        output_fields=["text"],
    )

    # sorted() builds a new list, so the ordering does not rely on the
    # client's result container supporting an in-place sort.
    ranked = sorted(hits, key=lambda hit: hit["distance"], reverse=True)
    for row in ranked:
        print(f'{row["distance"]:.6f} {row["entity"]["text"]}')
82428652cda1 rewrite
drewp@bigasterisk.com
parents: 0
diff changeset
51
0
drewp@bigasterisk.com
parents:
diff changeset
52
drewp@bigasterisk.com
parents:
diff changeset
# Command-line entry point: the script takes exactly one argument, the query
# text. An explicit check gives a usage message instead of an opaque
# unpacking error when the argument count is wrong.
if len(sys.argv) != 2:
    sys.exit(f"usage: {sys.argv[0]} QUERY")
q = sys.argv[1]

embedding_fn = OnnxEmbeddingFunction(model_name="GPTCache/paraphrase-albert-onnx")
client = MilvusClient("milvus_demo.db")
# The collection is rebuilt on every run before the query is answered.
rebuild(client, embedding_fn, dim=embedding_fn.dim)
search(q, embedding_fn, client)