comparison search/query.py @ 4:0e33c65f1904

playing with extractors
author drewp@bigasterisk.com
date Sat, 06 Jul 2024 16:42:36 -0700
parents query.py@82428652cda1
children f23b21bd0fce
comparison
equal deleted inserted replaced
3:ba1ce5921a4b 4:0e33c65f1904
1 import json
2 import sys
3 from pathlib import Path
4
5 from tqdm import tqdm
6
7 from pymilvus import MilvusClient
8 from milvus_model.dense.onnx import OnnxEmbeddingFunction
9
10 from extract_pdf import phrasesFromFile
11
12 from fastapi import FastAPI
13
def rebuild(
    client,
    embedding_fn,
    dim,
    *,
    collection_name="demo_collection",
    pdf_path=Path("data") / "Meetings2226Minutes_20240702182359526 (1).pdf",
):
    """Recreate a Milvus collection and fill it with embedded phrases from a PDF.

    Any existing collection named *collection_name* is dropped first, so all
    previously indexed data in it is discarded.

    Args:
        client: a ``MilvusClient`` (anything offering ``drop_collection``,
            ``has_collection``, ``create_collection`` and ``insert`` with the
            same keyword arguments works).
        embedding_fn: object whose ``encode_documents(list_of_str)`` returns
            one vector per input string.
        dim: vector dimension for the new collection; must match the length
            of the vectors *embedding_fn* produces.
        collection_name: name of the collection to rebuild.
        pdf_path: PDF whose phrases (as yielded by ``phrasesFromFile``) are
            indexed.
    """
    client.drop_collection(collection_name=collection_name)
    # Defensive: after the drop this should always be true, but only create
    # the collection if it is really gone.
    if not client.has_collection(collection_name=collection_name):
        client.create_collection(
            collection_name=collection_name,
            dimension=dim,
        )

    docs = []
    numbered_phrases = enumerate(phrasesFromFile(pdf_path))
    for i, (bbox, phrase) in tqdm(numbered_phrases,
                                  desc="rebuilding",
                                  unit=' phrase'):
        # One phrase per encode call, so the tqdm bar tracks encoding work.
        [vector] = embedding_fn.encode_documents([phrase])
        docs.append({
            # The phrase's position in the PDF stream serves as primary key.
            "id": i,
            "vector": vector,
            "text": phrase,
            # Stored as a JSON string; presumably because the collection has
            # no native field type for the bbox structure — confirm.
            "bbox": json.dumps(bbox),
        })
    res = client.insert(collection_name=collection_name, data=docs)
    print('insert:', res['insert_count'])
39
40
def search(q, embedding_fn, client, *, collection_name="demo_collection",
           limit=5):
    """Print, and return, the stored phrases that best match query *q*.

    Args:
        q: free-text query.
        embedding_fn: object whose ``encode_queries(list_of_str)`` returns one
            query vector per input string.
        client: a ``MilvusClient`` (anything with a compatible ``search``).
        collection_name: collection to search.
        limit: maximum number of hits to request.

    Returns:
        The hit rows for *q*, best first. Each row is a dict with at least
        ``"distance"`` and ``"entity"`` (holding ``"text"``).
    """
    query_vectors = embedding_fn.encode_queries([q])

    # One query vector in, so exactly one result list comes back.
    [hits] = client.search(
        collection_name=collection_name,
        data=query_vectors,
        limit=limit,
        output_fields=["text"],
    )
    # Largest distance first. This assumes a similarity metric where a larger
    # value means a closer match (e.g. COSINE/IP) — confirm against the
    # metric the collection was created with.
    hits.sort(key=lambda hit: hit["distance"], reverse=True)

    for hit in hits:
        print(f'{hit["distance"]:.6f} {hit["entity"]["text"]}')
    return hits
54
55
56 # q, = sys.argv[1:]
57
58 # https://huggingface.co/models?pipeline_tag=feature-extraction&library=onnx&sort=trending
59 # embedding_fn = OnnxEmbeddingFunction(model_name="jinaai/jina-embeddings-v2-base-en")
60 # client = MilvusClient("milvus_demo.db")
61 # rebuild(client, embedding_fn, dim=embedding_fn.dim)
62 # search(q, embedding_fn, client)
63
app = FastAPI()


@app.get("/sco/query")
def read_query1(q: str | None = None):
    """Handle ``GET /sco/query``.

    Work in progress: it only logs the query string and returns a fixed
    payload. It does not call ``search`` yet.

    Args:
        q: the ``?q=`` query parameter. It defaults to ``None`` so the
            parameter really is optional, as the ``str | None`` annotation
            intends. Without a default, FastAPI treats it as required and
            rejects requests that omit it.
    """
    print(f'1 {q=}')
    return {"Hello": "World"}