changeset 2:82428652cda1

rewrite
author drewp@bigasterisk.com
date Wed, 03 Jul 2024 20:20:08 -0700
parents c2176e9a2696
children ba1ce5921a4b
files query.py
diffstat 1 files changed, 52 insertions(+), 53 deletions(-) [+]
line wrap: on
line diff
--- a/query.py	Wed Jul 03 20:19:57 2024 -0700
+++ b/query.py	Wed Jul 03 20:20:08 2024 -0700
@@ -1,59 +1,58 @@
-from pathlib import Path
-from pprint import pprint
-import re
+import json
 import sys
+from pathlib import Path
+
+from tqdm import tqdm
+
+from pymilvus import MilvusClient
+from milvus_model.dense.onnx import OnnxEmbeddingFunction
+
 from extract_pdf import phrasesFromFile
-from pymilvus import model
-from pymilvus import MilvusClient
+
+
+def rebuild(client, embedding_fn, dim):
+    """Recreate "demo_collection" and fill it with phrases from the PDF.
+
+    Every (bbox, phrase) pair yielded by phrasesFromFile is embedded with
+    `embedding_fn` and inserted with the bbox serialized as JSON text.
+    """
+    name = "demo_collection"
+    client.drop_collection(collection_name=name)
+    if not client.has_collection(collection_name=name):
+        client.create_collection(collection_name=name, dimension=dim)
+
+    pdf = Path("data") / "Meetings2226Minutes_20240702182359526 (1).pdf"
+    rows = []
+    for idx, (box, text) in tqdm(enumerate(phrasesFromFile(pdf))):
+        # encode_documents takes a batch; unpack its single result
+        [embedding] = embedding_fn.encode_documents([text])
+        rows.append({
+            "id": idx,
+            "vector": embedding,
+            "text": text,
+            "bbox": json.dumps(box),
+        })
+    print('insert:', client.insert(collection_name=name, data=rows))
+
+
+def search(q, embedding_fn, client):
+    """Print the hits (limit 5) that "demo_collection" returns for `q`.
+
+    Each output line is the hit's distance (6 decimals) and its text,
+    ordered from the largest distance to the smallest.
+    """
+    [hits] = client.search(collection_name="demo_collection",
+                           data=embedding_fn.encode_queries([q]),
+                           limit=5,
+                           output_fields=["text"])
+    # Largest distance first; sorted() is stable, so ties keep search order.
+    for hit in sorted(hits, key=lambda h: h["distance"], reverse=True):
+        print(f'{hit["distance"]:.6f} {hit["entity"]["text"]}')
+
 
 q, = sys.argv[1:]
 
-def cleanup(phrase: str) -> str:
-    p = phrase.replace('\n', ' ')
-    p = re.sub(r'\s+', ' ', p)
-    if len(p) < 5:
-        return ''
-    return p
-
-
-embedding_fn = model.DefaultEmbeddingFunction()
-
+embedding_fn = OnnxEmbeddingFunction(model_name="GPTCache/paraphrase-albert-onnx")  # shared by indexing and querying
 client = MilvusClient("milvus_demo.db")
-
-# client.drop_collection(collection_name="demo_collection")
-# if not client.has_collection(collection_name="demo_collection"):
-#     client.create_collection(
-#         collection_name="demo_collection",
-#         dimension=768,  # The vectors we will use in this demo has 768 dimensions
-# )
-
-# docs = []
-# for i, (bbox, phrase) in enumerate(phrasesFromFile(Path("data") / "Meetings2226Minutes_20240702182359526 (1).pdf")):
-#     phrase = cleanup(phrase)
-#     print(f'{phrase=}')
-#     if not phrase:
-#         continue
-
-#     [vector] = embedding_fn.encode_documents([phrase])
-#     doc = {
-        
-#     "id": i,
-#     "vector": vector,
-#     "text": phrase,
-# }
-#     docs.append(doc)
-# res = client.insert(collection_name="demo_collection", data=docs)
-# print('insert:', res)
-
-query_vectors = embedding_fn.encode_queries([q])
-
-[query_result] = client.search(
-    collection_name="demo_collection",  
-    data=query_vectors,  
-    limit=15,  
-    output_fields=["text"],  
-)
-
-for row in query_result:
-    print(f'{row["distance"]:.6f} {row["entity"]["text"]}')
-# import ipdb; ipdb.set_trace()
\ No newline at end of file
+rebuild(client, embedding_fn, dim=embedding_fn.dim)  # NOTE(review): drops and re-indexes the collection on every run
+search(q, embedding_fn, client)  # q is the single command-line argument