# view search/query.py @ 8:f23b21bd0fce
#
# apex search
# author drewp@bigasterisk.com
# date Sun, 07 Jul 2024 16:26:56 -0700
# parents 0e33c65f1904
# children
# line wrap: on
# line source

from dataclasses import dataclass
import html
import json
from pprint import pprint
import sys
from pathlib import Path
from typing import Iterable

from tqdm import tqdm

from pymilvus import MilvusClient
from milvus_model.dense.onnx import OnnxEmbeddingFunction

from extract_pdf import files, phrasesFromFile

from fastapi import FastAPI
from search_apex import Search


def rebuild(client, embedding_fn, dim: int, *,
            collection_name: str = "demo_collection",
            pdf_path: "Path | str | None" = None) -> None:
    """Recreate a collection from the phrases extracted from one PDF.

    Drops ``collection_name``, creates it again with vectors of size
    ``dim``, embeds each phrase yielded by ``phrasesFromFile(pdf_path)``
    and inserts one row per phrase, then prints the reported insert count.

    Args:
        client: Object providing ``drop_collection``, ``has_collection``,
            ``create_collection`` and ``insert`` (the disabled driver in
            this file passes a ``MilvusClient``).
        embedding_fn: Object whose ``encode_documents([text])`` returns a
            one-element sequence containing the vector for ``text``.
        dim: Vector dimension used when creating the collection.
        collection_name: Collection to rebuild. Defaults to the name the
            rest of this file queries.
        pdf_path: PDF to index. When omitted, the meeting-minutes file
            under ``data/`` is used, as before.
    """
    if pdf_path is None:
        pdf_path = (Path("data") /
                    "Meetings2226Minutes_20240702182359526 (1).pdf")

    client.drop_collection(collection_name=collection_name)
    # After the drop this check is expected to pass; it is kept as a guard
    # so an existing collection is never re-created over.
    if not client.has_collection(collection_name=collection_name):
        client.create_collection(
            collection_name=collection_name,
            dimension=dim,
        )

    numbered_phrases = tqdm(enumerate(phrasesFromFile(Path(pdf_path))),
                            desc="rebuilding",
                            unit=' phrase')
    docs = []
    for i, (bbox, phrase) in numbered_phrases:
        # One phrase per call, so the result unpacks to exactly one vector.
        [vector] = embedding_fn.encode_documents([phrase])
        docs.append({
            "id": i,
            "vector": vector,
            "text": phrase,
            # The bounding box is stored as a JSON-encoded string.
            "bbox": json.dumps(bbox),
        })

    res = client.insert(collection_name=collection_name, data=docs)
    print('insert:', res['insert_count'])


def xxsearch(q, embedding_fn, client):
    """Search "demo_collection" for ``q`` and print the hits.

    The query text is encoded with ``embedding_fn.encode_queries`` and sent
    to ``client.search`` with a limit of 5, requesting the "text" field.
    Hits are printed one per line as ``<distance> <text>``, largest
    distance value first. Nothing is returned.
    """

    def by_distance(hit):
        return hit["distance"]

    # A single query vector is submitted, so exactly one hit list comes back.
    [hits] = client.search(
        collection_name="demo_collection",
        data=embedding_fn.encode_queries([q]),
        limit=5,
        output_fields=["text"],
    )

    for hit in sorted(hits, key=by_distance, reverse=True):
        score = hit["distance"]
        text = hit["entity"]["text"]
        print(f'{score:.6f} {text}')


# Disabled script-style driver, kept for reference: it read the query from the
# command line, rebuilt the Milvus collection, and ran one search.
# NOTE(review): the last line calls `search(...)`, but the function with that
# signature in this file is named `xxsearch`, and `search` is now bound to the
# Search instance below -- rename the call before re-enabling.
# q, = sys.argv[1:]

# Hugging Face listing of ONNX feature-extraction models:
# https://huggingface.co/models?pipeline_tag=feature-extraction&library=onnx&sort=trending
# embedding_fn = OnnxEmbeddingFunction(model_name="jinaai/jina-embeddings-v2-base-en")
# client = MilvusClient("milvus_demo.db")
# rebuild(client, embedding_fn, dim=embedding_fn.dim)
# search(q, embedding_fn, client)

# FastAPI application; the route handler below registers on it via decorator.
app = FastAPI()

# Search object from search_apex, constructed once at import time and used by
# the /sco/query handler.
search = Search()


@app.get("/sco/query")
def read_query1(q: str):
    """Handle GET /sco/query by running ``q`` through the module-level ``search``.

    Returns a dict with a single "results" key holding whatever
    ``search.search(q)`` returned.
    """
    return {"results": search.search(q)}