view search/query.py @ 4:0e33c65f1904

playing with extractors
author drewp@bigasterisk.com
date Sat, 06 Jul 2024 16:42:36 -0700
parents query.py@82428652cda1
children f23b21bd0fce
line wrap: on
line source

import json
import sys
from pathlib import Path

from tqdm import tqdm

from pymilvus import MilvusClient
from milvus_model.dense.onnx import OnnxEmbeddingFunction

from extract_pdf import phrasesFromFile

from fastapi import FastAPI

def rebuild(client, embedding_fn, dim, *,
            collection_name="demo_collection",
            pdf_path=Path("data") /
            "Meetings2226Minutes_20240702182359526 (1).pdf"):
    """Recreate the vector collection from the phrases of one PDF.

    Any existing collection named *collection_name* is dropped, a fresh one
    is created with vectors of size *dim*, and every ``(bbox, phrase)`` pair
    yielded by ``phrasesFromFile(pdf_path)`` is embedded and inserted with a
    sequential integer id.

    Args:
        client: a ``MilvusClient`` (anything providing ``has_collection``,
            ``drop_collection``, ``create_collection`` and ``insert``).
        embedding_fn: object whose ``encode_documents(list_of_str)`` returns
            one vector per input string.
        dim: vector dimension of the new collection; presumably must match
            the size of ``embedding_fn``'s vectors — confirm.
        collection_name: name of the collection to rebuild.
        pdf_path: PDF whose phrases are indexed.
    """
    # Start from an empty collection every time. After the drop the
    # collection cannot exist, so creation needs no existence check.
    if client.has_collection(collection_name=collection_name):
        client.drop_collection(collection_name=collection_name)
    client.create_collection(
        collection_name=collection_name,
        dimension=dim,
    )

    docs = []
    phrases = phrasesFromFile(pdf_path)
    for i, (bbox, phrase) in tqdm(enumerate(phrases),
                                  desc="rebuilding",
                                  unit=' phrase'):
        # Encoded one phrase at a time inside the loop, so the progress bar
        # reflects the embedding work as well as extraction.
        [vector] = embedding_fn.encode_documents([phrase])
        docs.append({
            "id": i,
            "vector": vector,
            # Extra, non-vector fields; presumably kept via the quick-setup
            # collection's dynamic fields — confirm against the schema.
            "text": phrase,
            # bbox is stored as a JSON string.
            "bbox": json.dumps(bbox),
        })
    res = client.insert(collection_name=collection_name, data=docs)
    print('insert:', res['insert_count'])


def search(q, embedding_fn, client, *, limit=5,
           collection_name="demo_collection"):
    """Print and return the stored phrases nearest to the query *q*.

    Args:
        q: free-text query.
        embedding_fn: object whose ``encode_queries(list_of_str)`` returns
            one vector per input string.
        client: a ``MilvusClient`` (anything with a compatible ``search``).
        limit: maximum number of hits to fetch.
        collection_name: collection to search.

    Returns:
        The hits sorted by ``"distance"`` descending. Each hit is a dict
        with a ``"distance"`` score and ``"entity"["text"]``. Each hit is
        also printed as ``<distance> <text>``.
    """
    query_vectors = embedding_fn.encode_queries([q])

    # A single query vector goes in, so exactly one result list comes back.
    [hits] = client.search(
        collection_name=collection_name,
        data=query_vectors,
        limit=limit,
        output_fields=["text"],
    )
    # Highest score first. NOTE(review): this assumes a similarity metric
    # (e.g. COSINE/IP) where larger means closer; it is the wrong order for L2.
    hits = sorted(hits, key=lambda hit: hit["distance"], reverse=True)

    for hit in hits:
        print(f'{hit["distance"]:.6f} {hit["entity"]["text"]}')
    return hits


# q, = sys.argv[1:]

# https://huggingface.co/models?pipeline_tag=feature-extraction&library=onnx&sort=trending
# embedding_fn = OnnxEmbeddingFunction(model_name="jinaai/jina-embeddings-v2-base-en")
# client = MilvusClient("milvus_demo.db")
# rebuild(client, embedding_fn, dim=embedding_fn.dim)
# search(q, embedding_fn, client)

app = FastAPI()


@app.get("/sco/query")
def read_query1(q: str | None = None):
    """Placeholder handler for ``GET /sco/query``.

    Args:
        q: optional ``?q=`` query string. It needs an explicit ``None``
            default: FastAPI treats a parameter with no default as
            required, even when it is annotated ``str | None``. Without the
            default, requests that omit ``q`` would be rejected with 422.

    Returns:
        A fixed placeholder payload; *q* is only logged to stdout.
    """
    print(f'1 {q=}')
    return {"Hello": "World"}