Mercurial > code > home > repos > sco-bot
changeset 2:82428652cda1
rewrite
author | drewp@bigasterisk.com |
---|---|
date | Wed, 03 Jul 2024 20:20:08 -0700 |
parents | c2176e9a2696 |
children | ba1ce5921a4b |
files | query.py |
diffstat | 1 files changed, 52 insertions(+), 53 deletions(-) [+] |
line wrap: on
line diff
```diff
--- a/query.py Wed Jul 03 20:19:57 2024 -0700
+++ b/query.py Wed Jul 03 20:20:08 2024 -0700
@@ -1,59 +1,58 @@
-from pathlib import Path
-from pprint import pprint
-import re
+import json
 import sys
+from pathlib import Path
+
+from tqdm import tqdm
+
+from pymilvus import MilvusClient
+from milvus_model.dense.onnx import OnnxEmbeddingFunction
+
 from extract_pdf import phrasesFromFile
-from pymilvus import model
-from pymilvus import MilvusClient
+
+
+def rebuild(client, embedding_fn, dim):
+    client.drop_collection(collection_name="demo_collection")
+    if not client.has_collection(collection_name="demo_collection"):
+        client.create_collection(
+            collection_name="demo_collection",
+            dimension=dim,
+        )
+
+    docs = []
+    for i, (bbox, phrase) in tqdm(enumerate(
+            phrasesFromFile(
+                Path("data") /
+                "Meetings2226Minutes_20240702182359526 (1).pdf"))):
+        [vector] = embedding_fn.encode_documents([phrase])
+        doc = {
+            "id": i,
+            "vector": vector,
+            "text": phrase,
+            "bbox": json.dumps(bbox),
+        }
+        docs.append(doc)
+    res = client.insert(collection_name="demo_collection", data=docs)
+    print('insert:', res)
+
+
+def search(q, embedding_fn, client):
+    query_vectors = embedding_fn.encode_queries([q])
+
+    [query_result] = client.search(
+        collection_name="demo_collection",
+        data=query_vectors,
+        limit=5,
+        output_fields=["text"],
+    )
+    query_result.sort(key=lambda x: x["distance"], reverse=True)
+
+    for row in query_result:
+        print(f'{row["distance"]:.6f} {row["entity"]["text"]}')
+
 
 q, = sys.argv[1:]
 
-def cleanup(phrase: str) -> str:
-    p = phrase.replace('\n', ' ')
-    p = re.sub(r'\s+', ' ', p)
-    if len(p) < 5:
-        return ''
-    return p
-
-
-embedding_fn = model.DefaultEmbeddingFunction()
-
+embedding_fn = OnnxEmbeddingFunction(model_name="GPTCache/paraphrase-albert-onnx")
 client = MilvusClient("milvus_demo.db")
-
-# client.drop_collection(collection_name="demo_collection")
-# if not client.has_collection(collection_name="demo_collection"):
-#     client.create_collection(
-#         collection_name="demo_collection",
-#         dimension=768, # The vectors we will use in this demo has 768 dimensions
-#     )
-
-# docs = []
-# for i, (bbox, phrase) in enumerate(phrasesFromFile(Path("data") / "Meetings2226Minutes_20240702182359526 (1).pdf")):
-#     phrase = cleanup(phrase)
-#     print(f'{phrase=}')
-#     if not phrase:
-#         continue
-
-#     [vector] = embedding_fn.encode_documents([phrase])
-#     doc = {
-
-#         "id": i,
-#         "vector": vector,
-#         "text": phrase,
-#     }
-#     docs.append(doc)
-# res = client.insert(collection_name="demo_collection", data=docs)
-# print('insert:', res)
-
-query_vectors = embedding_fn.encode_queries([q])
-
-[query_result] = client.search(
-    collection_name="demo_collection",
-    data=query_vectors,
-    limit=15,
-    output_fields=["text"],
-)
-
-for row in query_result:
-    print(f'{row["distance"]:.6f} {row["entity"]["text"]}')
-# import ipdb; ipdb.set_trace()
\ No newline at end of file
+rebuild(client, embedding_fn, dim=embedding_fn.dim)
+search(q, embedding_fn, client)
```