Examples

Simple demo
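
Define a Document table with a 3072-dimensional dense vector column, insert one embedded note, load it back, run a vector search, and drop the storage at the end.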

import asyncio
from typing import Optional

from vechord.embedding import GeminiDenseEmbedding
from vechord.registry import VechordRegistry
from vechord.spec import PrimaryKeyAutoIncrease, Table, Vector

DenseVector = Vector[3072]


class Document(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    title: str = ""
    text: str
    vec: DenseVector


async def main():
    async with (
        VechordRegistry(
            "simple",
            "postgresql://postgres:postgres@172.17.0.1:5432/",
            tables=[Document],
        ) as vr,
        GeminiDenseEmbedding() as emb,
    ):
        # add a document
        text = "my personal long note"
        doc = Document(
            title="note", text=text, vec=DenseVector(await emb.vectorize_chunk(text))
        )
        await vr.insert(doc)

        # load
        docs = await vr.select_by(Document.partial_init(), limit=1)
        print(docs)

        # query
        res = await vr.search_by_vector(
            Document, await emb.vectorize_query("note"), topk=1
        )
        print(res)

        # drop
        await vr.clear_storage(drop_table=True)


if __name__ == "__main__":
    asyncio.run(main())
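
select_by also accepts populated fields on partial_init as a filter, a pattern the later examples rely on (for instance Document.partial_init(uid=...)). A minimal sketch of that usage against the Document table above, to be placed inside the same async with block; the title filter value is only illustrative:

        # filter by a populated field instead of loading everything;
        # assumes the Document inserted above with title="note"
        notes = await vr.select_by(Document.partial_init(title="note"), limit=10)
        print(notes)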

BEIR evaluation
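
Download a BEIR dataset (scifact by default), embed the corpus and the test queries with Gemini, and report NDCG and recall@10 for plain vector search.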

import csv
import zipfile
from collections.abc import AsyncIterator
from pathlib import Path

import httpx
import msgspec
import rich.progress

from vechord.embedding import GeminiDenseEmbedding
from vechord.evaluate import BaseEvaluator
from vechord.registry import VechordRegistry
from vechord.spec import Table, Vector

BASE_URL = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"
DEFAULT_DATASET = "scifact"
TOP_K = 10

emb = GeminiDenseEmbedding()
DenseVector = Vector[3072]


def download_dataset(dataset: str, output: Path):
    output.mkdir(parents=True, exist_ok=True)
    zip = output / f"{dataset}.zip"
    if not zip.is_file():
        with (
            zip.open("wb") as f,
            httpx.stream("GET", BASE_URL.format(dataset)) as stream,
        ):
            total = int(stream.headers["Content-Length"])
            with rich.progress.Progress(
                "[progress.percentage]{task.percentage:>3.0f}%",
                rich.progress.BarColumn(bar_width=None),
                rich.progress.DownloadColumn(),
                rich.progress.TransferSpeedColumn(),
            ) as progress:
                download_task = progress.add_task("Download", total=total)
                for chunk in stream.iter_bytes():
                    f.write(chunk)
                    progress.update(
                        download_task, completed=stream.num_bytes_downloaded
                    )
    unzip_dir = output / dataset
    if not unzip_dir.is_dir():
        with zipfile.ZipFile(zip, "r") as f:
            f.extractall(output)
    return unzip_dir


class Corpus(Table):
    uid: str
    text: str
    title: str
    vector: DenseVector


class Query(Table):
    uid: str
    cid: str
    text: str
    vector: DenseVector


class Evaluation(msgspec.Struct):
    map: float
    ndcg: float
    recall: float


vr = VechordRegistry(
    DEFAULT_DATASET,
    "postgresql://postgres:postgres@172.17.0.1:5432/",
    tables=[Corpus, Query],
)


@vr.inject(output=Corpus)
async def load_corpus(dataset: str, output: Path) -> AsyncIterator[Corpus]:
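    # rows yielded here are written into the Corpus table by the registry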
    file = output / dataset / "corpus.jsonl"
    decoder = msgspec.json.Decoder()
    with file.open("r") as f:
        for line in f:
            item = decoder.decode(line)
            title = item.get("title", "")
            text = item.get("text", "")
            try:
                vector = await emb.vectorize_chunk(f"{title}\n{text}")
            except Exception as e:
                print(f"failed to vectorize {title}: {e}")
                continue
            yield Corpus(
                uid=item["_id"],
                text=text,
                title=title,
                vector=DenseVector(vector),
            )


@vr.inject(output=Query)
async def load_query(dataset: str, output: Path) -> AsyncIterator[Query]:
    file = output / dataset / "queries.jsonl"
    truth = output / dataset / "qrels" / "test.tsv"

    table = {}
    with open(truth, "r") as f:
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
        next(reader)  # skip header
        for row in reader:
            table[row[0]] = row[1]

    decoder = msgspec.json.Decoder()
    with file.open("r") as f:
        for line in f:
            item = decoder.decode(line)
            uid = item["_id"]
            if uid not in table:
                continue
            text = item.get("text", "")
            yield Query(
                uid=uid,
                cid=table[uid],
                text=text,
                vector=DenseVector(await emb.vectorize_query(text)),
            )


@vr.inject(input=Query)
async def evaluate(cid: str, vector: DenseVector) -> Evaluation:
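    # `input=Query` runs this once per stored Query row (fields matched by name);
    # calling `evaluate()` with no arguments collects the results into a list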
    docs = await vr.search_by_vector(Corpus, vector, topk=TOP_K)
    score = BaseEvaluator.evaluate_one(cid, [doc.uid for doc in docs])
    return Evaluation(
        map=score.get("map"),
        ndcg=score.get("ndcg"),
        recall=score.get(f"recall_{TOP_K}"),
    )


async def main():
    save_dir = Path("datasets")
    download_dataset(DEFAULT_DATASET, save_dir)

    async with vr, emb:
        await load_corpus(DEFAULT_DATASET, save_dir)
        await load_query(DEFAULT_DATASET, save_dir)

        res: list[Evaluation] = await evaluate()
        print("ndcg", sum(r.ndcg for r in res) / len(res))
        print("recall@10", sum(r.recall for r in res) / len(res))


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())

HTTP web service
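
Expose the load-and-chunk pipeline as an HTTP service: create_web_app wraps the registry and the pipeline, and uvicorn serves the resulting app.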

from datetime import datetime, timezone
from functools import partial
from typing import Annotated

import httpx
import msgspec
import uvicorn

from vechord.chunk import RegexChunker
from vechord.embedding import GeminiDenseEmbedding
from vechord.extract import SimpleExtractor
from vechord.registry import VechordRegistry
from vechord.service import create_web_app
from vechord.spec import (
    ForeignKey,
    PrimaryKeyAutoIncrease,
    Table,
    Vector,
)

URL = "https://paulgraham.com/{}.html"
DenseVector = Vector[3072]
emb = GeminiDenseEmbedding()
chunker = RegexChunker(size=1024, overlap=0)
extractor = SimpleExtractor()


class Document(Table, kw_only=True):
    uid: PrimaryKeyAutoIncrease | None = None
    title: str = ""
    text: str
    updated_at: datetime = msgspec.field(
        default_factory=partial(datetime.now, timezone.utc)
    )


class Chunk(Table, kw_only=True):
    uid: PrimaryKeyAutoIncrease | None = None
    doc_id: Annotated[int, ForeignKey[Document.uid]]
    text: str
    vector: DenseVector


vr = VechordRegistry(
    "http", "postgresql://postgres:postgres@172.17.0.1:5432/", tables=[Document, Chunk]
)


@vr.inject(output=Document)
def load_document(title: str) -> Document:
    with httpx.Client() as client:
        resp = client.get(URL.format(title))
        if resp.is_error:
            raise RuntimeError(f"Failed to fetch the document `{title}`")
        return Document(title=title, text=extractor.extract_html(resp.text))


@vr.inject(input=Document, output=Chunk)
async def chunk_document(uid: int, text: str) -> list[Chunk]:
    chunks = await chunker.segment(text)
    return [
        Chunk(
            doc_id=uid, text=chunk, vector=DenseVector(await emb.vectorize_chunk(chunk))
        )
        for chunk in chunks
    ]


if __name__ == "__main__":
    # this pipeline is used by the web app; it can also be run directly with `vr.run()`
    pipeline = vr.create_pipeline([load_document, chunk_document])
    app = create_web_app(vr, pipeline)

    uvicorn.run(app)

Contextual chunk augmentation
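
Load local PDF files, split them into chunks, and store a second, context-augmented embedding per chunk produced with GeminiAugmenter, so plain and contextual retrieval can be compared on the same data.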

from datetime import datetime
from typing import Annotated, Optional

from rich import print

from vechord import (
    GeminiAugmenter,
    GeminiDenseEmbedding,
    GeminiEvaluator,
    LocalLoader,
    RegexChunker,
    SimpleExtractor,
)
from vechord.registry import VechordRegistry
from vechord.spec import (
    ForeignKey,
    PrimaryKeyAutoIncrease,
    Table,
    Vector,
)

emb = GeminiDenseEmbedding()
DenseVector = Vector[3072]
extractor = SimpleExtractor()


class Document(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    digest: str
    filename: str
    text: str
    updated_at: datetime


class Chunk(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    doc_uid: Annotated[int, ForeignKey[Document.uid]]
    seq_id: int
    text: str
    vector: DenseVector


class ContextChunk(Table, kw_only=True):
    chunk_uid: Annotated[int, ForeignKey[Chunk.uid]]
    text: str
    vector: DenseVector


vr = VechordRegistry(
    "decorator",
    "postgresql://postgres:postgres@172.17.0.1:5432/",
    tables=[Document, Chunk, ContextChunk],
)


@vr.inject(output=Document)
def load_from_dir(dirpath: str) -> list[Document]:
    loader = LocalLoader(dirpath, include=[".pdf"])
    return [
        Document(
            digest=doc.digest,
            filename=doc.path,
            text=extractor.extract(doc),
            updated_at=doc.updated_at,
        )
        for doc in loader.load()
    ]


@vr.inject(input=Document, output=Chunk)
async def split_document(uid: int, text: str) -> list[Chunk]:
    chunker = RegexChunker(overlap=0)
    chunks = await chunker.segment(text)
    return [
        Chunk(
            doc_uid=uid,
            seq_id=i,
            text=chunk,
            vector=DenseVector(await emb.vectorize_chunk(chunk)),
        )
        for i, chunk in enumerate(chunks)
    ]


@vr.inject(input=Document, output=ContextChunk)
async def context_embedding(uid: int, text: str) -> list[ContextChunk]:
    chunks: list[Chunk] = await vr.select_by(
        Chunk.partial_init(doc_uid=uid), fields=["uid", "text"]
    )
    async with GeminiAugmenter() as augmenter:
        # `augment_context` is a coroutine (it is awaited in the Anthropic
        # example below), so await it before pairing each generated context
        # with its original chunk
        augments = await augmenter.augment_context(text, [c.text for c in chunks])
        context_chunks = [
            f"{context}\n{origin}"
            for (context, origin) in zip(
                augments, [c.text for c in chunks], strict=False
            )
        ]
    return [
        ContextChunk(
            chunk_uid=chunk_uid,
            text=augmented,
            vector=DenseVector(await emb.vectorize_chunk(augmented)),
        )
        for (chunk_uid, augmented) in zip(
            [c.uid for c in chunks], context_chunks, strict=False
        )
    ]


async def query_chunk(query: str) -> list[Chunk]:
    vector = await emb.vectorize_query(query)
    res: list[Chunk] = await vr.search_by_vector(Chunk, vector, topk=5)
    return res


async def query_context_chunk(query: str) -> list[ContextChunk]:
    vector = await emb.vectorize_query(query)
    res: list[ContextChunk] = await vr.search_by_vector(
        ContextChunk,
        vector,
        topk=5,
    )
    return res


@vr.inject(input=Chunk)
async def evaluate(uid: int, doc_uid: int, text: str):
    async with GeminiEvaluator() as evaluator:
        doc = (await vr.select_by(Document.partial_init(uid=doc_uid)))[0]
        query = await evaluator.produce_query(doc.text, text)
        retrieved = await query_chunk(query)
        score = evaluator.evaluate_one(str(uid), [str(r.uid) for r in retrieved])
    return score


async def main():
    async with vr, emb:
        await load_from_dir("./data")
        await split_document()
        await context_embedding()

        chunks = await query_chunk("vector search")
        print(chunks)

        scores = await evaluate()
        print("ndcg", sum(score["ndcg"] for score in scores) / len(scores))

        context_chunks = await query_context_chunk("vector search")
        print(context_chunks)

        await vr.clear_storage()


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())

Contextual retrieval with the Anthropic cookbook example
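
Reproduce the contextual-embeddings evaluation from the Anthropic cookbook: chunks are indexed with and without LLM-generated context, then vector, keyword, and hybrid retrieval (reciprocal rank fusion or Cohere reranking) are compared by Pass@k.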

"""Anthropic Cookbook Contextual Embedding Example.

Data can be found at https://github.com/anthropics/anthropic-cookbook.
"""

import json
from pathlib import Path
from time import perf_counter
from typing import Annotated, Optional

import httpx

from vechord.augment import GeminiAugmenter
from vechord.embedding import GeminiDenseEmbedding
from vechord.registry import VechordRegistry
from vechord.rerank import CohereReranker, ReciprocalRankFusion
from vechord.spec import (
    ForeignKey,
    Keyword,
    PrimaryKeyAutoIncrease,
    Table,
    UniqueIndex,
    Vector,
)

DenseVector = Vector[3072]
emb = GeminiDenseEmbedding()


class Document(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    uuid: Annotated[str, UniqueIndex()]
    content: str


class Chunk(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    doc_uuid: Annotated[str, ForeignKey[Document.uuid]]
    index: int
    content: str
    vector: DenseVector
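    # the keyword column backs the `search_by_keyword` queries below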
    keyword: Keyword


class ContextualChunk(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    doc_uuid: Annotated[str, ForeignKey[Document.uuid]]
    index: int
    content: str
    context: str
    vector: DenseVector
    keyword: Keyword


class Query(Table, kw_only=True):
    uid: Optional[PrimaryKeyAutoIncrease] = None
    content: str
    answer: str
    doc_uuids: list[str]
    chunk_index: list[int]
    vector: DenseVector


vr = VechordRegistry(
    "anthropic",
    "postgresql://postgres:postgres@172.17.0.1:5432/",
    tables=[Document, Chunk, ContextualChunk, Query],
)


def download_data(url: str, save_path: str):
    if Path(save_path).is_file():
        print(f"{save_path} already exists, skip download.")
        return
    with httpx.stream("GET", url) as response, open(save_path, "wb") as f:
        for chunk in response.iter_bytes():
            f.write(chunk)


async def load_data(filepath: str):
    with open(filepath, "r", encoding="utf-8") as f:
        docs = json.load(f)
        for doc in docs:
            await vr.insert(
                Document(
                    uuid=doc["original_uuid"],
                    content=doc["content"],
                )
            )
            for chunk in doc["chunks"]:
                await vr.insert(
                    Chunk(
                        doc_uuid=doc["original_uuid"],
                        index=chunk["original_index"],
                        content=chunk["content"],
                        vector=await emb.vectorize_chunk(chunk["content"]),
                        keyword=Keyword(chunk["content"]),
                    )
                )


async def load_contextual_chunks(filepath: str):
    async with GeminiAugmenter() as augmenter:
        with open(filepath, "r", encoding="utf-8") as f:
            docs = json.load(f)
            for doc in docs:
                chunks = doc["chunks"]
                augments = await augmenter.augment_context(
                    doc["content"], [chunk["content"] for chunk in chunks]
                )
                if len(augments) != len(chunks):
                    print(
                        f"augments length not match for uuid: {doc['original_uuid']}, {len(augments)} != {len(chunks)}"
                    )
                for chunk, context in zip(chunks, augments, strict=False):
                    contextual_content = f"{chunk['content']}\n\n{context}"
                    await vr.insert(
                        ContextualChunk(
                            doc_uuid=doc["original_uuid"],
                            index=chunk["original_index"],
                            content=chunk["content"],
                            context=context,
                            vector=await emb.vectorize_chunk(contextual_content),
                            keyword=Keyword(contextual_content),
                        )
                    )


async def load_query(filepath: str):
    queries = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            query = json.loads(line)
            queries.append(
                Query(
                    content=query["query"],
                    answer=query["answer"],
                    doc_uuids=[x[0] for x in query["golden_chunk_uuids"]],
                    chunk_index=[x[1] for x in query["golden_chunk_uuids"]],
                    vector=await emb.vectorize_query(query["query"]),
                )
            )
    await vr.copy_bulk(queries)


async def vector_search(query: Query, topk: int) -> list[Chunk]:
    return await vr.search_by_vector(Chunk, query.vector, topk=topk)


async def vector_contextual_search(query: Query, topk: int) -> list[ContextualChunk]:
    return await vr.search_by_vector(ContextualChunk, query.vector, topk=topk)


async def keyword_search(query: Query, topk: int) -> list[Chunk]:
    return await vr.search_by_keyword(Chunk, query.content, topk=topk)


async def keyword_contextual_search(query: Query, topk: int) -> list[ContextualChunk]:
    return await vr.search_by_keyword(ContextualChunk, query.content, topk=topk)


async def hybrid_search_fuse(query: Query, topk: int) -> list[Chunk]:
    rrf = ReciprocalRankFusion()
    return rrf.fuse(
        [await vector_search(query, topk), await keyword_search(query, topk)]
    )[:topk]


async def hybrid_contextual_search_fuse(
    query: Query, topk: int
) -> list[ContextualChunk]:
    rrf = ReciprocalRankFusion()
    return rrf.fuse(
        [
            await vector_contextual_search(query, topk),
            await keyword_contextual_search(query, topk),
        ]
    )[:topk]


async def hybrid_search_rerank(query: Query, topk: int, boost=3) -> list[Chunk]:
    vecs = await vector_search(query, topk * boost)
    keys = await keyword_search(query, topk * boost)
    chunks = list({chunk.uid: chunk for chunk in vecs + keys}.values())
    async with CohereReranker() as ranker:
        indices = await ranker.rerank(
            query.content, [chunk.content for chunk in chunks]
        )
        return [chunks[i] for i in indices[:topk]]


async def hybrid_contextual_search_rerank(
    query: Query, topk: int, boost=3
) -> list[ContextualChunk]:
    vecs = await vector_contextual_search(query, topk * boost)
    keys = await keyword_contextual_search(query, topk * boost)
    chunks = list({chunk.uid: chunk for chunk in vecs + keys}.values())
    async with CohereReranker() as ranker:
        indices = await ranker.rerank(
            query.content, [f"{chunk.content}\n{chunk.context}" for chunk in chunks]
        )
        return [chunks[i] for i in indices[:topk]]


async def evaluate(topk=5, search_func=vector_search):
    print(f"TopK={topk}, search by: {search_func.__name__}")
    queries: list[Query] = await vr.select_by(Query.partial_init())
    total_score = 0
    start = perf_counter()
    for query in queries:
        chunks: list[Chunk] = await search_func(query, topk)
        count = 0
        for doc_uuid, chunk_index in zip(
            query.doc_uuids, query.chunk_index, strict=True
        ):
            for chunk in chunks:
                if chunk.doc_uuid == doc_uuid and chunk.index == chunk_index:
                    count += 1
                    break
        score = count / len(query.doc_uuids)
        total_score += score

    print(
        f"Pass@{topk}: {total_score / len(queries):.4f}, total queries: {len(queries)}, QPS: {len(queries) / (perf_counter() - start):.3f}"
    )


async def main(data_path: str):
    dir = Path(data_path)
    dir.mkdir(parents=True, exist_ok=True)
    download_data(
        "https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/codebase_chunks.json",
        dir / "codebase_chunks.json",
    )
    download_data(
        "https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/evaluation_set.jsonl",
        dir / "evaluation_set.jsonl",
    )
    async with vr, emb:
        await load_data(dir / "codebase_chunks.json")
        await load_query(dir / "evaluation_set.jsonl")
        await load_contextual_chunks(dir / "codebase_chunks.json")

        for topk in [5, 10]:
            print("=" * 50)
            await evaluate(topk=topk, search_func=vector_search)
            await evaluate(topk=topk, search_func=keyword_search)
            await evaluate(topk=topk, search_func=hybrid_search_fuse)
            await evaluate(topk=topk, search_func=hybrid_search_rerank)
            await evaluate(topk=topk, search_func=vector_contextual_search)
            await evaluate(topk=topk, search_func=keyword_contextual_search)
            await evaluate(topk=topk, search_func=hybrid_contextual_search_fuse)
            await evaluate(topk=topk, search_func=hybrid_contextual_search_rerank)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main("datasets"))

Evaluate with generated queries
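
Chunk a Paul Graham essay, ask GeminiEvaluator to generate one query per chunk, and measure how well vector search retrieves the originating chunk (NDCG and recall@10).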

from dataclasses import dataclass
from typing import Annotated

import httpx

from vechord.chunk import RegexChunker
from vechord.embedding import GeminiDenseEmbedding
from vechord.evaluate import GeminiEvaluator
from vechord.extract import SimpleExtractor
from vechord.registry import VechordRegistry
from vechord.spec import (
    ForeignKey,
    PrimaryKeyAutoIncrease,
    Table,
    Vector,
)

URL = "https://paulgraham.com/{}.html"
ARTICLE = "best"
TOP_K = 10

DenseVector = Vector[3072]
emb = GeminiDenseEmbedding()
evaluator = GeminiEvaluator()
extractor = SimpleExtractor()


class Chunk(Table, kw_only=True):
    uid: PrimaryKeyAutoIncrease | None = None
    text: str
    vector: DenseVector


class Query(Table, kw_only=True):
    uid: PrimaryKeyAutoIncrease | None = None
    cid: Annotated[int, ForeignKey[Chunk.uid]]
    text: str
    vector: DenseVector


@dataclass(frozen=True)
class Evaluation:
    map: float
    ndcg: float
    recall: float


vr = VechordRegistry(
    ARTICLE, "postgresql://postgres:postgres@172.17.0.1:5432/", tables=[Chunk, Query]
)

with httpx.Client() as client:
    resp = client.get(URL.format(ARTICLE))
doc = extractor.extract_html(resp.text)


@vr.inject(output=Chunk)
async def segment_essay() -> list[Chunk]:
    chunker = RegexChunker()
    chunks = await chunker.segment(doc)
    return [
        Chunk(text=chunk, vector=DenseVector(await emb.vectorize_chunk(chunk)))
        for chunk in chunks
    ]


@vr.inject(input=Chunk, output=Query)
async def create_query(uid: int, text: str) -> Query:
    query = await evaluator.produce_query(doc, text)
    return Query(
        cid=uid, text=query, vector=DenseVector(await emb.vectorize_query(query))
    )


@vr.inject(input=Query)
async def evaluate(cid: int, vector: DenseVector) -> Evaluation:
    chunks: list[Chunk] = await vr.search_by_vector(Chunk, vector, topk=TOP_K)
    score = evaluator.evaluate_one(str(cid), [str(chunk.uid) for chunk in chunks])
    return Evaluation(
        map=score["map"], ndcg=score["ndcg"], recall=score[f"recall_{TOP_K}"]
    )


async def main():
    async with vr, emb, evaluator:
        await segment_essay()
        await create_query()

        res: list[Evaluation] = await evaluate()
        print("ndcg", sum(r.ndcg for r in res) / len(res))
        print(f"recall@{TOP_K}", sum(r.recall for r in res) / len(res))


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())

Hybrid search with reranking
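
Index chunks with both a dense vector and a keyword column, retrieve candidates by keyword and by vector, merge them, and rerank the merged list with Cohere.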

import httpx

from vechord.chunk import RegexChunker
from vechord.embedding import GeminiDenseEmbedding
from vechord.extract import SimpleExtractor
from vechord.registry import VechordRegistry
from vechord.rerank import CohereReranker
from vechord.spec import DefaultDocument, Keyword, create_chunk_with_dim

URL = "https://paulgraham.com/{}.html"
Chunk = create_chunk_with_dim(3072)
emb = GeminiDenseEmbedding()
chunker = RegexChunker(size=1024, overlap=0)
reranker = CohereReranker()
extractor = SimpleExtractor()


vr = VechordRegistry(
    "hybrid",
    "postgresql://postgres:postgres@172.17.0.1:5432/",
    tables=[DefaultDocument, Chunk],
)


@vr.inject(output=DefaultDocument)
async def load_document(title: str) -> DefaultDocument:
    async with httpx.AsyncClient() as client:
        resp = await client.get(URL.format(title))
        if resp.is_error:
            raise RuntimeError(f"Failed to fetch the document `{title}`")
        return DefaultDocument(title=title, text=extractor.extract_html(resp.text))


@vr.inject(input=DefaultDocument, output=Chunk)
async def chunk_document(uid: int, text: str) -> list[Chunk]:
    chunks = await chunker.segment(text)
    return [
        Chunk(
            doc_id=uid,
            text=chunk,
            vec=await emb.vectorize_chunk(chunk),
            keyword=Keyword(chunk),
        )
        for chunk in chunks
    ]


async def search_and_rerank(query: str, topk: int) -> list[Chunk]:
    text_retrieves = await vr.search_by_keyword(Chunk, query, topk=topk)
    vec_retrieves = await vr.search_by_vector(
        Chunk, await emb.vectorize_query(query), topk=topk
    )
    # merge the two candidate lists, de-duplicating by primary key
    chunks = list(
        {chunk.uid: chunk for chunk in text_retrieves + vec_retrieves}.values()
    )
    indices = await reranker.rerank(query, [chunk.text for chunk in chunks])
    return [chunks[i] for i in indices[:topk]]


async def main():
    async with vr, emb, reranker:
        await load_document("smart")
        await chunk_document()
        chunks = await search_and_rerank("smart", 3)
        print(chunks)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
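
If no Cohere key is available, the two result lists can also be merged with ReciprocalRankFusion, as the Anthropic example does. A minimal sketch under that assumption, reusing the registry, embedding, and Chunk table defined above:

from vechord.rerank import ReciprocalRankFusion


async def search_and_fuse(query: str, topk: int) -> list[Chunk]:
    # fuse the keyword and vector rankings instead of calling the Cohere reranker
    rrf = ReciprocalRankFusion()
    keyword_results = await vr.search_by_keyword(Chunk, query, topk=topk)
    vector_results = await vr.search_by_vector(
        Chunk, await emb.vectorize_query(query), topk=topk
    )
    return rrf.fuse([keyword_results, vector_results])[:topk]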

Entity-relation graph-like retrieval
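
Extend the BEIR setup with Gemini entity and relation extraction: entities and relations are stored in their own tables, vector-search results are expanded through similar entities and relation endpoints at query time, and the resulting graph can be rendered with networkx.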

import csv
import zipfile
from collections.abc import AsyncIterator
from pathlib import Path
from uuid import UUID

import httpx
import msgspec
import rich.progress

from vechord.embedding import GeminiDenseEmbedding
from vechord.entity import GeminiEntityRecognizer
from vechord.evaluate import BaseEvaluator
from vechord.registry import VechordRegistry
from vechord.spec import PrimaryKeyUUID, Table, Vector

BASE_URL = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"
DEFAULT_DATASET = "scifact"
TOP_K = 10

DenseVector = Vector[3072]
emb = GeminiDenseEmbedding()
ner = GeminiEntityRecognizer()


def download_dataset(dataset: str, output: Path):
    output.mkdir(parents=True, exist_ok=True)
    zip = output / f"{dataset}.zip"
    if not zip.is_file():
        with (
            zip.open("wb") as f,
            httpx.stream("GET", BASE_URL.format(dataset)) as stream,
        ):
            total = int(stream.headers["Content-Length"])
            with rich.progress.Progress(
                "[progress.percentage]{task.percentage:>3.0f}%",
                rich.progress.BarColumn(bar_width=None),
                rich.progress.DownloadColumn(),
                rich.progress.TransferSpeedColumn(),
            ) as progress:
                download_task = progress.add_task("Download", total=total)
                for chunk in stream.iter_bytes():
                    f.write(chunk)
                    progress.update(
                        download_task, completed=stream.num_bytes_downloaded
                    )
    unzip_dir = output / dataset
    if not unzip_dir.is_dir():
        with zipfile.ZipFile(zip, "r") as f:
            f.extractall(output)
    return unzip_dir


class Chunk(Table, kw_only=True):
    uuid: PrimaryKeyUUID = msgspec.field(default_factory=PrimaryKeyUUID.factory)
    uid: str
    text: str
    vec: DenseVector
    ent_uuids: list[UUID]


class Query(Table):
    uid: str
    cid: str
    text: str
    vec: DenseVector


class Entity(Table, kw_only=True):
    uuid: PrimaryKeyUUID = msgspec.field(default_factory=PrimaryKeyUUID.factory)
    text: str
    label: str
    vec: DenseVector
    chunk_uuids: list[UUID]


class Relation(Table, kw_only=True):
    uuid: PrimaryKeyUUID = msgspec.field(default_factory=PrimaryKeyUUID.factory)
    source: UUID
    target: UUID
    text: str
    vec: DenseVector


class Evaluation(msgspec.Struct):
    map: float
    ndcg: float
    recall: float


vr = VechordRegistry(
    "graph",
    "postgresql://postgres:postgres@172.17.0.1:5432/",
    tables=[Chunk, Query, Entity, Relation],
)


@vr.inject(output=Chunk)
async def load_corpus(dataset: str, output: Path) -> AsyncIterator[Chunk]:
    file = output / dataset / "corpus.jsonl"
    decoder = msgspec.json.Decoder()
    entities: dict[str, Entity] = {}
    relations: dict[str, Relation] = {}
    with file.open("r") as f:
        for line in f:
            item = decoder.decode(line)
            text = f"{item['title']}\n{item['text']}"
            try:
                vector = await emb.vectorize_chunk(text)
            except Exception as e:
                print(f"failed to vectorize {text}: {e}")
                continue
            ents, rels = await ner.recognize_with_relations(text)
            for ent in ents:
                if ent.text not in entities:
                    entities[ent.text] = Entity(
                        text=ent.text,
                        label=ent.label,
                        vec=DenseVector(
                            await emb.vectorize_chunk(f"{ent.text} {ent.description}")
                        ),
                        chunk_uuids=[],
                    )
            for rel in rels:
                if rel.source.text not in entities or rel.target.text not in entities:
                    continue
                if rel.description not in relations:
                    relations[rel.description] = Relation(
                        source=entities[rel.source.text].uuid,
                        target=entities[rel.target.text].uuid,
                        text=rel.description,
                        vec=DenseVector(await emb.vectorize_chunk(rel.description)),
                    )

            chunk = Chunk(
                uid=item["_id"],
                text=text,
                vec=DenseVector(vector),
                ent_uuids=[entities[ent.text].uuid for ent in ents],
            )
            for ent in ents:
                entities[ent.text].chunk_uuids.append(chunk.uuid)
            yield chunk
    await vr.copy_bulk(list(entities.values()))
    await vr.copy_bulk(list(relations.values()))


@vr.inject(output=Query)
async def load_query(dataset: str, output: Path) -> AsyncIterator[Query]:
    file = output / dataset / "queries.jsonl"
    truth = output / dataset / "qrels" / "test.tsv"

    table = {}
    with open(truth, "r") as f:
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
        next(reader)  # skip header
        for row in reader:
            table[row[0]] = row[1]

    decoder = msgspec.json.Decoder()
    with file.open("r") as f:
        for line in f:
            item = decoder.decode(line)
            uid = item["_id"]
            if uid not in table:
                continue
            text = item.get("text", "")
            yield Query(
                uid=uid,
                cid=table[uid],
                text=text,
                vec=DenseVector(await emb.vectorize_query(text)),
            )


async def expand_by_text(text: str) -> list[Chunk]:
    ents = await ner.recognize(text)
    chunks = []
    for ent in ents:
        entity = await vr.select_by(Entity.partial_init(text=ent.text))
        if not entity:
            continue
        entity = entity[0]
        chunks.extend(
            res[0]
            for res in [
                await vr.select_by(Chunk.partial_init(uuid=chunk_uuid))
                for chunk_uuid in entity.chunk_uuids
            ]
        )
    return chunks


async def expand_by_graph(text: str, topk=3) -> list[Chunk]:
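    # recognize entities/relations in the query text, look up similar stored
    # entities (and relation endpoints) by vector, then collect the chunks
    # linked to those entities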
    ents, rels = await ner.recognize_with_relations(text)
    if not ents:
        return []
    entity_text = " ".join(f"{ent.text} {ent.description}" for ent in ents)
    similar_ents = await vr.search_by_vector(
        Entity, await emb.vectorize_query(entity_text), topk=topk
    )
    ents = set(ent.uuid for ent in similar_ents)
    if rels:
        relation_text = " ".join(rel.description for rel in rels)
        similar_rels = await vr.search_by_vector(
            Relation, await emb.vectorize_query(relation_text), topk=topk
        )
        ents |= set(rel.source for rel in similar_rels) | set(
            rel.target for rel in similar_rels
        )
    chunks = []
    for ent_uuid in ents:
        res = await vr.select_by(Entity.partial_init(uuid=ent_uuid))
        if not res:
            continue
        entity = res[0]
        chunks.extend(
            res[0]
            for res in [
                await vr.select_by(Chunk.partial_init(uuid=chunk_uuid))
                for chunk_uuid in entity.chunk_uuids
            ]
        )
    return chunks


@vr.inject(input=Query)
async def evaluate(cid: str, text: str, vec: DenseVector) -> Evaluation:
    chunks: list[Chunk] = await vr.search_by_vector(Chunk, vec, topk=TOP_K)
    expands = await expand_by_graph(text)
    final_chunks = list({chunk.uuid: chunk for chunk in chunks + expands}.values())
    # TODO: rerank
    score = BaseEvaluator.evaluate_one(cid, [doc.uid for doc in final_chunks])
    return Evaluation(
        map=score.get("map"),
        ndcg=score.get("ndcg"),
        recall=score.get(f"recall_{TOP_K}"),
    )


async def display_graph(save_to_file: bool = True):
    import matplotlib.pyplot as plt
    import networkx as nx

    graph = nx.Graph()
    rels: list[Relation] = await vr.select_by(Relation.partial_init())
    ent_table: dict[UUID, Entity] = {}
    for rel in rels:
        for uuid in (rel.source, rel.target):
            if uuid not in ent_table:
                ent = await vr.select_by(Entity.partial_init(uuid=uuid))
                if ent:
                    ent_table[uuid] = ent[0]

    ents = list(ent_table.values())
    edge_labels = {}
    for i, ent in enumerate(ents):
        graph.add_node(ent.uuid, label=ent.label, text=ent.text, index=i)
    for rel in rels:
        graph.add_edge(rel.source, rel.target, text=rel.text)
        edge_labels[(rel.source, rel.target)] = rel.text

    fig = plt.figure(figsize=(12, 10))
    pos = nx.spring_layout(graph, k=0.5, iterations=50, seed=42)
    nx.draw_networkx_nodes(
        graph,
        pos,
        node_size=2500,
        node_color="skyblue",
        alpha=0.7,
        edgecolors="black",
        linewidths=1,
    )
    nx.draw_networkx_edges(
        graph, pos, edgelist=graph.edges(), width=1, alpha=0.5, edge_color="gray"
    )
    nx.draw_networkx_labels(graph, pos, font_size=10)
    nx.draw_networkx_edge_labels(
        graph, pos, edge_labels=edge_labels, font_size=8, label_pos=0.5
    )
    plt.title("Entity-Relation Graph")
    plt.axis("off")
    plt.tight_layout()
    if save_to_file:
        with open("graph.png", "wb") as f:
            fig.savefig(f, format="png")
    else:
        plt.show()


async def main():
    save_dir = Path("datasets")
    async with vr, emb, ner:
        # download_dataset is a regular (synchronous) function
        download_dataset(DEFAULT_DATASET, save_dir)

        await load_corpus(DEFAULT_DATASET, save_dir)
        await load_query(DEFAULT_DATASET, save_dir)

        res: list[Evaluation] = await evaluate()
        print("ndcg", sum(r.ndcg for r in res) / len(res))
        print("recall@10", sum(r.recall for r in res) / len(res))

        await display_graph()


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())