Examples
Simple demo
import asyncio
from typing import Optional
from vechord.embedding import GeminiDenseEmbedding
from vechord.registry import VechordRegistry
from vechord.spec import PrimaryKeyAutoIncrease, Table, Vector
DenseVector = Vector[3072]
class Document(Table, kw_only=True):
uid: Optional[PrimaryKeyAutoIncrease] = None
title: str = ""
text: str
vec: DenseVector
async def main():
async with (
VechordRegistry(
"simple",
"postgresql://postgres:postgres@172.17.0.1:5432/",
tables=[Document],
) as vr,
GeminiDenseEmbedding() as emb,
):
# add a document
text = "my personal long note"
doc = Document(
title="note", text=text, vec=DenseVector(await emb.vectorize_chunk(text))
)
await vr.insert(doc)
# load
docs = await vr.select_by(Document.partial_init(), limit=1)
print(docs)
# query
res = await vr.search_by_vector(
Document, await emb.vectorize_query("note"), topk=1
)
print(res)
# drop
await vr.clear_storage(drop_table=True)
if __name__ == "__main__":
asyncio.run(main())
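The demo above inserts, selects, searches, and drops inside a single `main()`. As a small follow-up sketch (the helper name is made up; everything else reuses the `Document` table and the `vr` registry from the demo), `select_by` also accepts concrete column values and a field projection. It would need to run inside the `async with` block, before `clear_storage` drops the table.

async def fetch_notes(vr: VechordRegistry) -> list[Document]:
    # return only the `uid` and `text` columns of rows whose `title` is "note";
    # the other columns stay unset on the partially initialized results
    return await vr.select_by(
        Document.partial_init(title="note"), fields=["uid", "text"]
    )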
BEIR evaluation
import csv
import zipfile
from collections.abc import AsyncIterator
from pathlib import Path
import httpx
import msgspec
import rich.progress
from vechord.embedding import GeminiDenseEmbedding
from vechord.evaluate import BaseEvaluator
from vechord.registry import VechordRegistry
from vechord.spec import Table, Vector
BASE_URL = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"
DEFAULT_DATASET = "scifact"
TOP_K = 10
emb = GeminiDenseEmbedding()
DenseVector = Vector[3072]
def download_dataset(dataset: str, output: Path):
output.mkdir(parents=True, exist_ok=True)
zip = output / f"{dataset}.zip"
if not zip.is_file():
with (
zip.open("wb") as f,
httpx.stream("GET", BASE_URL.format(dataset)) as stream,
):
total = int(stream.headers["Content-Length"])
with rich.progress.Progress(
"[progress.percentage]{task.percentage:>3.0f}%",
rich.progress.BarColumn(bar_width=None),
rich.progress.DownloadColumn(),
rich.progress.TransferSpeedColumn(),
) as progress:
download_task = progress.add_task("Download", total=total)
for chunk in stream.iter_bytes():
f.write(chunk)
progress.update(
download_task, completed=stream.num_bytes_downloaded
)
unzip_dir = output / dataset
if not unzip_dir.is_dir():
with zipfile.ZipFile(zip, "r") as f:
f.extractall(output)
return unzip_dir
class Corpus(Table):
uid: str
text: str
title: str
vector: DenseVector
class Query(Table):
uid: str
cid: str
text: str
vector: DenseVector
class Evaluation(msgspec.Struct):
map: float
ndcg: float
recall: float
vr = VechordRegistry(
DEFAULT_DATASET,
"postgresql://postgres:postgres@172.17.0.1:5432/",
tables=[Corpus, Query],
)
@vr.inject(output=Corpus)
async def load_corpus(dataset: str, output: Path) -> AsyncIterator[Corpus]:
file = output / dataset / "corpus.jsonl"
decoder = msgspec.json.Decoder()
with file.open("r") as f:
for line in f:
item = decoder.decode(line)
title = item.get("title", "")
text = item.get("text", "")
try:
vector = await emb.vectorize_chunk(f"{title}\n{text}")
except Exception as e:
print(f"failed to vectorize {title}: {e}")
continue
yield Corpus(
uid=item["_id"],
text=text,
title=title,
vector=DenseVector(vector),
)
@vr.inject(output=Query)
async def load_query(dataset: str, output: Path) -> AsyncIterator[Query]:
file = output / dataset / "queries.jsonl"
truth = output / dataset / "qrels" / "test.tsv"
table = {}
with open(truth, "r") as f:
reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
next(reader) # skip header
for row in reader:
table[row[0]] = row[1]
decoder = msgspec.json.Decoder()
with file.open("r") as f:
for line in f:
item = decoder.decode(line)
uid = item["_id"]
if uid not in table:
continue
text = item.get("text", "")
yield Query(
uid=uid,
cid=table[uid],
text=text,
vector=DenseVector(await emb.vectorize_query(text)),
)
@vr.inject(input=Query)
async def evaluate(cid: str, vector: DenseVector) -> Evaluation:
docs = await vr.search_by_vector(Corpus, vector, topk=TOP_K)
score = BaseEvaluator.evaluate_one(cid, [doc.uid for doc in docs])
return Evaluation(
map=score.get("map"),
ndcg=score.get("ndcg"),
recall=score.get(f"recall_{TOP_K}"),
)
async def main():
save_dir = Path("datasets")
download_dataset(DEFAULT_DATASET, save_dir)
async with vr, emb:
await load_corpus(DEFAULT_DATASET, save_dir)
await load_query(DEFAULT_DATASET, save_dir)
res: list[Evaluation] = await evaluate()
print("ndcg", sum(r.ndcg for r in res) / len(res))
print("recall@10", sum(r.recall for r in res) / len(res))
if __name__ == "__main__":
import asyncio
asyncio.run(main())
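The `Evaluation` struct above also carries `map`, even though `main()` only prints ndcg and recall. A tiny hypothetical `report` helper (a sketch reusing only the types defined above) averages all three metrics:

def report(results: list[Evaluation]) -> None:
    n = len(results)
    # mean of each metric over all evaluated queries
    print("map", sum(r.map for r in results) / n)
    print("ndcg", sum(r.ndcg for r in results) / n)
    print(f"recall@{TOP_K}", sum(r.recall for r in results) / n)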
HTTP web service
from datetime import datetime, timezone
from functools import partial
from typing import Annotated
import httpx
import msgspec
import uvicorn
from vechord.chunk import RegexChunker
from vechord.embedding import GeminiDenseEmbedding
from vechord.extract import SimpleExtractor
from vechord.registry import VechordRegistry
from vechord.service import create_web_app
from vechord.spec import (
ForeignKey,
PrimaryKeyAutoIncrease,
Table,
Vector,
)
URL = "https://paulgraham.com/{}.html"
DenseVector = Vector[3072]
emb = GeminiDenseEmbedding()
chunker = RegexChunker(size=1024, overlap=0)
extractor = SimpleExtractor()
class Document(Table, kw_only=True):
uid: PrimaryKeyAutoIncrease | None = None
title: str = ""
text: str
updated_at: datetime = msgspec.field(
default_factory=partial(datetime.now, timezone.utc)
)
class Chunk(Table, kw_only=True):
uid: PrimaryKeyAutoIncrease | None = None
doc_id: Annotated[int, ForeignKey[Document.uid]]
text: str
vector: DenseVector
vr = VechordRegistry(
"http", "postgresql://postgres:postgres@172.17.0.1:5432/", tables=[Document, Chunk]
)
@vr.inject(output=Document)
def load_document(title: str) -> Document:
with httpx.Client() as client:
resp = client.get(URL.format(title))
if resp.is_error:
raise RuntimeError(f"Failed to fetch the document `{title}`")
return Document(title=title, text=extractor.extract_html(resp.text))
@vr.inject(input=Document, output=Chunk)
async def chunk_document(uid: int, text: str) -> list[Chunk]:
chunks = await chunker.segment(text)
return [
Chunk(
doc_id=uid, text=chunk, vector=DenseVector(await emb.vectorize_chunk(chunk))
)
for chunk in chunks
]
if __name__ == "__main__":
    # this pipeline is used by the web app; alternatively, you can run it with `vr.run()`
pipeline = vr.create_pipeline([load_document, chunk_document])
app = create_web_app(vr, pipeline)
uvicorn.run(app)
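The `__main__` block above serves the pipeline over HTTP. If you prefer to exercise the same pipeline without the web server, a sketch like the hypothetical `run_once` below reuses the registry, tables, and injected functions defined above (the essay title and the query string are arbitrary examples):

async def run_once():
    async with vr, emb:
        await load_document("smart")  # fetch and extract one essay
        await chunk_document()  # chunk and embed every new document
        hits = await vr.search_by_vector(
            Chunk, await emb.vectorize_query("startup advice"), topk=3
        )
        print(hits)

# asyncio.run(run_once())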
Contextual chunk augmentation
from datetime import datetime
from typing import Annotated, Optional
from rich import print
from vechord import (
GeminiAugmenter,
GeminiDenseEmbedding,
GeminiEvaluator,
LocalLoader,
RegexChunker,
SimpleExtractor,
)
from vechord.registry import VechordRegistry
from vechord.spec import (
ForeignKey,
PrimaryKeyAutoIncrease,
Table,
Vector,
)
emb = GeminiDenseEmbedding()
DenseVector = Vector[3072]
extractor = SimpleExtractor()
class Document(Table, kw_only=True):
uid: Optional[PrimaryKeyAutoIncrease] = None
digest: str
filename: str
text: str
updated_at: datetime
class Chunk(Table, kw_only=True):
uid: Optional[PrimaryKeyAutoIncrease] = None
doc_uid: Annotated[int, ForeignKey[Document.uid]]
seq_id: int
text: str
vector: DenseVector
class ContextChunk(Table, kw_only=True):
chunk_uid: Annotated[int, ForeignKey[Chunk.uid]]
text: str
vector: DenseVector
vr = VechordRegistry(
"decorator",
"postgresql://postgres:postgres@172.17.0.1:5432/",
tables=[Document, Chunk, ContextChunk],
)
@vr.inject(output=Document)
def load_from_dir(dirpath: str) -> list[Document]:
loader = LocalLoader(dirpath, include=[".pdf"])
return [
Document(
digest=doc.digest,
filename=doc.path,
text=extractor.extract(doc),
updated_at=doc.updated_at,
)
for doc in loader.load()
]
@vr.inject(input=Document, output=Chunk)
async def split_document(uid: int, text: str) -> list[Chunk]:
chunker = RegexChunker(overlap=0)
chunks = await chunker.segment(text)
return [
Chunk(
doc_uid=uid,
seq_id=i,
text=chunk,
vector=DenseVector(await emb.vectorize_chunk(chunk)),
)
for i, chunk in enumerate(chunks)
]
@vr.inject(input=Document, output=ContextChunk)
async def context_embedding(uid: int, text: str) -> list[ContextChunk]:
chunks: list[Chunk] = await vr.select_by(
Chunk.partial_init(doc_uid=uid), fields=["uid", "text"]
)
    async with GeminiAugmenter() as augmentor:
        # `augment_context` is a coroutine, so await it before pairing each
        # generated context with the chunk text it was produced for
        contexts = await augmentor.augment_context(text, [c.text for c in chunks])
        context_chunks = [
            f"{context}\n{origin}"
            for (context, origin) in zip(
                contexts, [c.text for c in chunks], strict=False
            )
        ]
return [
ContextChunk(
chunk_uid=chunk_uid,
text=augmented,
vector=DenseVector(await emb.vectorize_chunk(augmented)),
)
for (chunk_uid, augmented) in zip(
[c.uid for c in chunks], context_chunks, strict=False
)
]
async def query_chunk(query: str) -> list[Chunk]:
vector = await emb.vectorize_query(query)
res: list[Chunk] = await vr.search_by_vector(Chunk, vector, topk=5)
return res
async def query_context_chunk(query: str) -> list[ContextChunk]:
vector = await emb.vectorize_query(query)
res: list[ContextChunk] = await vr.search_by_vector(
ContextChunk,
vector,
topk=5,
)
return res
@vr.inject(input=Chunk)
async def evaluate(uid: int, doc_uid: int, text: str):
async with GeminiEvaluator() as evaluator:
doc = (await vr.select_by(Document.partial_init(uid=doc_uid)))[0]
query = await evaluator.produce_query(doc.text, text)
retrieved = await query_chunk(query)
score = evaluator.evaluate_one(str(uid), [str(r.uid) for r in retrieved])
return score
async def main():
async with vr, emb:
await load_from_dir("./data")
await split_document()
await context_embedding()
chunks = await query_chunk("vector search")
print(chunks)
scores = await evaluate()
print(sum(scores) / len(scores))
context_chunks = await query_context_chunk("vector search")
print(context_chunks)
await vr.clear_storage()
if __name__ == "__main__":
import asyncio
asyncio.run(main())
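Since `ContextChunk.chunk_uid` is a foreign key to `Chunk.uid`, a retrieved context chunk can be mapped back to the plain chunk it augments. A small sketch (the helper name is made up; everything else reuses the tables and registry above):

async def resolve_original(ctx: ContextChunk) -> Optional[Chunk]:
    # look up the plain chunk row that this augmented chunk was derived from
    rows = await vr.select_by(Chunk.partial_init(uid=ctx.chunk_uid))
    return rows[0] if rows else None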
Contextual retrieval with the Anthropic example
"""Anthropic Cookbook Contextual Embedding Example.
The data can be found at "https://github.com/anthropics/anthropic-cookbook".
"""
import json
from pathlib import Path
from time import perf_counter
from typing import Annotated, Optional
import httpx
from vechord.augment import GeminiAugmenter
from vechord.embedding import GeminiDenseEmbedding
from vechord.registry import VechordRegistry
from vechord.rerank import CohereReranker, ReciprocalRankFusion
from vechord.spec import (
ForeignKey,
Keyword,
PrimaryKeyAutoIncrease,
Table,
UniqueIndex,
Vector,
)
DenseVector = Vector[3072]
emb = GeminiDenseEmbedding()
class Document(Table, kw_only=True):
uid: Optional[PrimaryKeyAutoIncrease] = None
uuid: Annotated[str, UniqueIndex()]
content: str
class Chunk(Table, kw_only=True):
uid: Optional[PrimaryKeyAutoIncrease] = None
doc_uuid: Annotated[str, ForeignKey[Document.uuid]]
index: int
content: str
vector: DenseVector
keyword: Keyword
class ContextualChunk(Table, kw_only=True):
uid: Optional[PrimaryKeyAutoIncrease] = None
doc_uuid: Annotated[str, ForeignKey[Document.uuid]]
index: int
content: str
context: str
vector: DenseVector
keyword: Keyword
class Query(Table, kw_only=True):
uid: Optional[PrimaryKeyAutoIncrease] = None
content: str
answer: str
doc_uuids: list[str]
chunk_index: list[int]
vector: DenseVector
vr = VechordRegistry(
"anthropic",
"postgresql://postgres:postgres@172.17.0.1:5432/",
tables=[Document, Chunk, ContextualChunk, Query],
)
def download_data(url: str, save_path: str):
if Path(save_path).is_file():
print(f"{save_path} already exists, skip download.")
return
with httpx.stream("GET", url) as response, open(save_path, "wb") as f:
for chunk in response.iter_bytes():
f.write(chunk)
async def load_data(filepath: str):
with open(filepath, "r", encoding="utf-8") as f:
docs = json.load(f)
for doc in docs:
await vr.insert(
Document(
uuid=doc["original_uuid"],
content=doc["content"],
)
)
for chunk in doc["chunks"]:
await vr.insert(
Chunk(
doc_uuid=doc["original_uuid"],
index=chunk["original_index"],
content=chunk["content"],
vector=await emb.vectorize_chunk(chunk["content"]),
keyword=Keyword(chunk["content"]),
)
)
async def load_contextual_chunks(filepath: str):
async with GeminiAugmenter() as augmenter:
with open(filepath, "r", encoding="utf-8") as f:
docs = json.load(f)
for doc in docs:
chunks = doc["chunks"]
augments = await augmenter.augment_context(
doc["content"], [chunk["content"] for chunk in chunks]
)
if len(augments) != len(chunks):
print(
f"augments length not match for uuid: {doc['original_uuid']}, {len(augments)} != {len(chunks)}"
)
for chunk, context in zip(chunks, augments, strict=False):
contextual_content = f"{chunk['content']}\n\n{context}"
await vr.insert(
ContextualChunk(
doc_uuid=doc["original_uuid"],
index=chunk["original_index"],
content=chunk["content"],
context=context,
vector=await emb.vectorize_chunk(contextual_content),
keyword=Keyword(contextual_content),
)
)
async def load_query(filepath: str):
queries = []
with open(filepath, "r", encoding="utf-8") as f:
for line in f:
query = json.loads(line)
queries.append(
Query(
content=query["query"],
answer=query["answer"],
doc_uuids=[x[0] for x in query["golden_chunk_uuids"]],
chunk_index=[x[1] for x in query["golden_chunk_uuids"]],
vector=await emb.vectorize_query(query["query"]),
)
)
await vr.copy_bulk(queries)
async def vector_search(query: Query, topk: int) -> list[Chunk]:
return await vr.search_by_vector(Chunk, query.vector, topk=topk)
async def vector_contextual_search(query: Query, topk: int) -> list[ContextualChunk]:
return await vr.search_by_vector(ContextualChunk, query.vector, topk=topk)
async def keyword_search(query: Query, topk: int) -> list[Chunk]:
return await vr.search_by_keyword(Chunk, query.content, topk=topk)
async def keyword_contextual_search(query: Query, topk: int) -> list[ContextualChunk]:
return await vr.search_by_keyword(ContextualChunk, query.content, topk=topk)
async def hybrid_search_fuse(query: Query, topk: int) -> list[Chunk]:
rrf = ReciprocalRankFusion()
return rrf.fuse(
[await vector_search(query, topk), await keyword_search(query, topk)]
)[:topk]
async def hybrid_contextual_search_fuse(
query: Query, topk: int
) -> list[ContextualChunk]:
rrf = ReciprocalRankFusion()
return rrf.fuse(
[
await vector_contextual_search(query, topk),
await keyword_contextual_search(query, topk),
]
)[:topk]
async def hybrid_search_rerank(query: Query, topk: int, boost=3) -> list[Chunk]:
vecs = await vector_search(query, topk * boost)
keys = await keyword_search(query, topk * boost)
chunks = list({chunk.uid: chunk for chunk in vecs + keys}.values())
async with CohereReranker() as ranker:
indices = await ranker.rerank(
query.content, [chunk.content for chunk in chunks]
)
return [chunks[i] for i in indices[:topk]]
async def hybrid_contextual_search_rerank(
query: Query, topk: int, boost=3
) -> list[ContextualChunk]:
vecs = await vector_contextual_search(query, topk * boost)
keys = await keyword_contextual_search(query, topk * boost)
chunks = list({chunk.uid: chunk for chunk in vecs + keys}.values())
async with CohereReranker() as ranker:
indices = await ranker.rerank(
query.content, [f"{chunk.content}\n{chunk.context}" for chunk in chunks]
)
return [chunks[i] for i in indices[:topk]]
async def evaluate(topk=5, search_func=vector_search):
print(f"TopK={topk}, search by: {search_func.__name__}")
queries: list[Query] = await vr.select_by(Query.partial_init())
total_score = 0
start = perf_counter()
for query in queries:
        chunks: list[Chunk] = await search_func(query, topk)
count = 0
for doc_uuid, chunk_index in zip(
query.doc_uuids, query.chunk_index, strict=True
):
for chunk in chunks:
if chunk.doc_uuid == doc_uuid and chunk.index == chunk_index:
count += 1
break
score = count / len(query.doc_uuids)
total_score += score
print(
f"Pass@{topk}: {total_score / len(queries):.4f}, total queries: {len(queries)}, QPS: {len(queries) / (perf_counter() - start):.3f}"
)
async def main(data_path: str):
dir = Path(data_path)
dir.mkdir(parents=True, exist_ok=True)
download_data(
"https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/codebase_chunks.json",
dir / "codebase_chunks.json",
)
download_data(
"https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/evaluation_set.jsonl",
dir / "evaluation_set.jsonl",
)
async with vr, emb:
await load_data(dir / "codebase_chunks.json")
await load_query(dir / "evaluation_set.jsonl")
await load_contextual_chunks(dir / "codebase_chunks.json")
for topk in [5, 10]:
print("=" * 50)
await evaluate(topk=topk, search_func=vector_search)
await evaluate(topk=topk, search_func=keyword_search)
await evaluate(topk=topk, search_func=hybrid_search_fuse)
await evaluate(topk=topk, search_func=hybrid_search_rerank)
await evaluate(topk=topk, search_func=vector_contextual_search)
await evaluate(topk=topk, search_func=keyword_contextual_search)
await evaluate(topk=topk, search_func=hybrid_contextual_search_fuse)
await evaluate(topk=topk, search_func=hybrid_contextual_search_rerank)
if __name__ == "__main__":
import asyncio
asyncio.run(main("datasets"))
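Beyond the benchmark loop, the same search helpers can serve ad-hoc questions. The sketch below is a hypothetical `ask` helper that wraps a question in a throwaway `Query` row (the ground-truth fields are left empty because they are only needed for evaluation) and runs it through the contextual hybrid rerank path; it assumes the tables have already been populated by `main()`:

async def ask(question: str, topk: int = 5) -> list[ContextualChunk]:
    async with vr, emb:
        query = Query(
            content=question,
            answer="",
            doc_uuids=[],
            chunk_index=[],
            vector=await emb.vectorize_query(question),
        )
        return await hybrid_contextual_search_rerank(query, topk)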
Evaluate with generated queries
from dataclasses import dataclass
from typing import Annotated
import httpx
from vechord.chunk import RegexChunker
from vechord.embedding import GeminiDenseEmbedding
from vechord.evaluate import GeminiEvaluator
from vechord.extract import SimpleExtractor
from vechord.registry import VechordRegistry
from vechord.spec import (
ForeignKey,
PrimaryKeyAutoIncrease,
Table,
Vector,
)
URL = "https://paulgraham.com/{}.html"
ARTICLE = "best"
TOP_K = 10
DenseVector = Vector[3072]
emb = GeminiDenseEmbedding()
evaluator = GeminiEvaluator()
extractor = SimpleExtractor()
class Chunk(Table, kw_only=True):
uid: PrimaryKeyAutoIncrease | None = None
text: str
vector: DenseVector
class Query(Table, kw_only=True):
uid: PrimaryKeyAutoIncrease | None = None
cid: Annotated[int, ForeignKey[Chunk.uid]]
text: str
vector: DenseVector
@dataclass(frozen=True)
class Evaluation:
map: float
ndcg: float
recall: float
vr = VechordRegistry(
ARTICLE, "postgresql://postgres:postgres@172.17.0.1:5432/", tables=[Chunk, Query]
)
with httpx.Client() as client:
resp = client.get(URL.format(ARTICLE))
doc = extractor.extract_html(resp.text)
@vr.inject(output=Chunk)
async def segment_essay() -> list[Chunk]:
chunker = RegexChunker()
chunks = await chunker.segment(doc)
return [
Chunk(text=chunk, vector=DenseVector(await emb.vectorize_chunk(chunk)))
for chunk in chunks
]
@vr.inject(input=Chunk, output=Query)
async def create_query(uid: int, text: str) -> Query:
query = await evaluator.produce_query(doc, text)
return Query(
cid=uid, text=query, vector=DenseVector(await emb.vectorize_chunk(query))
)
@vr.inject(input=Query)
async def evaluate(cid: int, vector: DenseVector) -> Evaluation:
chunks: list[Chunk] = await vr.search_by_vector(Chunk, vector, topk=TOP_K)
score = evaluator.evaluate_one(str(cid), [str(chunk.uid) for chunk in chunks])
return Evaluation(
map=score["map"], ndcg=score["ndcg"], recall=score[f"recall_{TOP_K}"]
)
async def main():
async with vr, emb, evaluator:
await segment_essay()
await create_query()
res: list[Evaluation] = await evaluate()
print("ndcg", sum(r.ndcg for r in res) / len(res))
print(f"recall@{TOP_K}", sum(r.recall for r in res) / len(res))
if __name__ == "__main__":
import asyncio
asyncio.run(main())
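It can be useful to eyeball the generated queries before trusting the averages. A quick inspection sketch (a hypothetical helper reusing the registry and tables above):

async def show_samples(n: int = 3) -> None:
    queries: list[Query] = await vr.select_by(Query.partial_init(), limit=n)
    for q in queries:
        # `cid` points back to the chunk the query was generated from
        chunk = (await vr.select_by(Chunk.partial_init(uid=q.cid)))[0]
        print(f"{q.text!r} -> {chunk.text[:80]!r}")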
Hybrid search with rerank
import httpx
from vechord.chunk import RegexChunker
from vechord.embedding import GeminiDenseEmbedding
from vechord.extract import SimpleExtractor
from vechord.registry import VechordRegistry
from vechord.rerank import CohereReranker
from vechord.spec import DefaultDocument, Keyword, create_chunk_with_dim
URL = "https://paulgraham.com/{}.html"
Chunk = create_chunk_with_dim(3072)
emb = GeminiDenseEmbedding()
chunker = RegexChunker(size=1024, overlap=0)
reranker = CohereReranker()
extractor = SimpleExtractor()
vr = VechordRegistry(
"hybrid",
"postgresql://postgres:postgres@172.17.0.1:5432/",
tables=[DefaultDocument, Chunk],
)
@vr.inject(output=DefaultDocument)
async def load_document(title: str) -> DefaultDocument:
async with httpx.AsyncClient() as client:
resp = await client.get(URL.format(title))
if resp.is_error:
raise RuntimeError(f"Failed to fetch the document `{title}`")
return DefaultDocument(title=title, text=extractor.extract_html(resp.text))
@vr.inject(input=DefaultDocument, output=Chunk)
async def chunk_document(uid: int, text: str) -> list[Chunk]:
chunks = await chunker.segment(text)
return [
Chunk(
doc_id=uid,
text=chunk,
vec=await emb.vectorize_chunk(chunk),
keyword=Keyword(chunk),
)
for chunk in chunks
]
async def search_and_rerank(query: str, topk: int) -> list[Chunk]:
text_retrieves = await vr.search_by_keyword(Chunk, query, topk=topk)
    vec_retrieves = await vr.search_by_vector(
        Chunk, await emb.vectorize_query(query), topk=topk
    )
    chunks = list(
        {chunk.uid: chunk for chunk in text_retrieves + vec_retrieves}.values()
)
indices = await reranker.rerank(query, [chunk.text for chunk in chunks])
return [chunks[i] for i in indices[:topk]]
async def main():
async with vr, emb, reranker:
await load_document("smart")
await chunk_document()
chunks = await search_and_rerank("smart", 3)
print(chunks)
if __name__ == "__main__":
import asyncio
asyncio.run(main())
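If a Cohere key is not available, the two result lists can also be merged with the `ReciprocalRankFusion` helper shown in the Anthropic example above instead of a cross-encoder rerank. A sketch under that assumption:

from vechord.rerank import ReciprocalRankFusion


async def search_and_fuse(query: str, topk: int) -> list[Chunk]:
    rrf = ReciprocalRankFusion()
    text_hits = await vr.search_by_keyword(Chunk, query, topk=topk)
    vec_hits = await vr.search_by_vector(
        Chunk, await emb.vectorize_query(query), topk=topk
    )
    # rank-based fusion needs no extra model call, only the two ranked lists
    return rrf.fuse([vec_hits, text_hits])[:topk]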
Entity-relation graph-like retrieval
import csv
import zipfile
from collections.abc import AsyncIterator
from pathlib import Path
from uuid import UUID
import httpx
import msgspec
import rich.progress
from vechord.embedding import GeminiDenseEmbedding
from vechord.entity import GeminiEntityRecognizer
from vechord.evaluate import BaseEvaluator
from vechord.registry import VechordRegistry
from vechord.spec import PrimaryKeyUUID, Table, Vector
BASE_URL = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"
DEFAULT_DATASET = "scifact"
TOP_K = 10
DenseVector = Vector[3072]
emb = GeminiDenseEmbedding()
ner = GeminiEntityRecognizer()
def download_dataset(dataset: str, output: Path):
output.mkdir(parents=True, exist_ok=True)
zip = output / f"{dataset}.zip"
if not zip.is_file():
with (
zip.open("wb") as f,
httpx.stream("GET", BASE_URL.format(dataset)) as stream,
):
total = int(stream.headers["Content-Length"])
with rich.progress.Progress(
"[progress.percentage]{task.percentage:>3.0f}%",
rich.progress.BarColumn(bar_width=None),
rich.progress.DownloadColumn(),
rich.progress.TransferSpeedColumn(),
) as progress:
download_task = progress.add_task("Download", total=total)
for chunk in stream.iter_bytes():
f.write(chunk)
progress.update(
download_task, completed=stream.num_bytes_downloaded
)
unzip_dir = output / dataset
if not unzip_dir.is_dir():
with zipfile.ZipFile(zip, "r") as f:
f.extractall(output)
return unzip_dir
class Chunk(Table, kw_only=True):
uuid: PrimaryKeyUUID = msgspec.field(default_factory=PrimaryKeyUUID.factory)
uid: str
text: str
vec: DenseVector
ent_uuids: list[UUID]
class Query(Table):
uid: str
cid: str
text: str
vec: DenseVector
class Entity(Table, kw_only=True):
uuid: PrimaryKeyUUID = msgspec.field(default_factory=PrimaryKeyUUID.factory)
text: str
label: str
vec: DenseVector
chunk_uuids: list[UUID]
class Relation(Table, kw_only=True):
uuid: PrimaryKeyUUID = msgspec.field(default_factory=PrimaryKeyUUID.factory)
source: UUID
target: UUID
text: str
vec: DenseVector
class Evaluation(msgspec.Struct):
map: float
ndcg: float
recall: float
vr = VechordRegistry(
"graph",
"postgresql://postgres:postgres@172.17.0.1:5432/",
tables=[Chunk, Query, Entity, Relation],
)
@vr.inject(output=Chunk)
async def load_corpus(dataset: str, output: Path) -> AsyncIterator[Chunk]:
file = output / dataset / "corpus.jsonl"
decoder = msgspec.json.Decoder()
entities: dict[str, Entity] = {}
relations: dict[str, Relation] = {}
with file.open("r") as f:
for line in f:
item = decoder.decode(line)
text = f"{item['title']}\n{item['text']}"
try:
vector = await emb.vectorize_chunk(text)
except Exception as e:
print(f"failed to vectorize {text}: {e}")
continue
ents, rels = await ner.recognize_with_relations(text)
for ent in ents:
if ent.text not in entities:
entities[ent.text] = Entity(
text=ent.text,
label=ent.label,
vec=DenseVector(
await emb.vectorize_chunk(f"{ent.text} {ent.description}")
),
chunk_uuids=[],
)
for rel in rels:
if rel.source.text not in entities or rel.target.text not in entities:
continue
if rel.description not in relations:
relations[rel.description] = Relation(
source=entities[rel.source.text].uuid,
target=entities[rel.target.text].uuid,
text=rel.description,
vec=DenseVector(await emb.vectorize_chunk(rel.description)),
)
chunk = Chunk(
uid=item["_id"],
text=text,
vec=DenseVector(vector),
ent_uuids=[entities[ent.text].uuid for ent in ents],
)
for ent in ents:
entities[ent.text].chunk_uuids.append(chunk.uuid)
yield chunk
await vr.copy_bulk(list(entities.values()))
await vr.copy_bulk(list(relations.values()))
@vr.inject(output=Query)
async def load_query(dataset: str, output: Path) -> AsyncIterator[Query]:
file = output / dataset / "queries.jsonl"
truth = output / dataset / "qrels" / "test.tsv"
table = {}
with open(truth, "r") as f:
reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
next(reader) # skip header
for row in reader:
table[row[0]] = row[1]
decoder = msgspec.json.Decoder()
with file.open("r") as f:
for line in f:
item = decoder.decode(line)
uid = item["_id"]
if uid not in table:
continue
text = item.get("text", "")
yield Query(
uid=uid,
cid=table[uid],
text=text,
vec=DenseVector(await emb.vectorize_query(text)),
)
async def expand_by_text(text: str) -> list[Chunk]:
ents = await ner.recognize(text)
chunks = []
for ent in ents:
entity = await vr.select_by(Entity.partial_init(text=ent.text))
if not entity:
continue
entity = entity[0]
chunks.extend(
res[0]
for res in [
await vr.select_by(Chunk.partial_init(uuid=chunk_uuid))
for chunk_uuid in entity.chunk_uuids
]
)
return chunks
async def expand_by_graph(text: str, topk=3) -> list[Chunk]:
ents, rels = await ner.recognize_with_relations(text)
if not ents:
return []
entity_text = " ".join(f"{ent.text} {ent.description}" for ent in ents)
similar_ents = await vr.search_by_vector(
Entity, await emb.vectorize_query(entity_text), topk=topk
)
ents = set(ent.uuid for ent in similar_ents)
if rels:
relation_text = " ".join(rel.description for rel in rels)
similar_rels = await vr.search_by_vector(
Relation, await emb.vectorize_query(relation_text), topk=topk
)
ents |= set(rel.source for rel in similar_rels) | set(
rel.target for rel in similar_rels
)
chunks = []
for ent_uuid in ents:
res = await vr.select_by(Entity.partial_init(uuid=ent_uuid))
if not res:
continue
entity = res[0]
chunks.extend(
res[0]
for res in [
await vr.select_by(Chunk.partial_init(uuid=chunk_uuid))
for chunk_uuid in entity.chunk_uuids
]
)
return chunks
@vr.inject(input=Query)
async def evaluate(cid: str, text: str, vec: DenseVector) -> Evaluation:
chunks: list[Chunk] = await vr.search_by_vector(Chunk, vec, topk=TOP_K)
expands = await expand_by_graph(text)
final_chunks = list({chunk.uuid: chunk for chunk in chunks + expands}.values())
# TODO: rerank
score = BaseEvaluator.evaluate_one(cid, [doc.uid for doc in final_chunks])
return Evaluation(
map=score.get("map"),
ndcg=score.get("ndcg"),
recall=score.get(f"recall_{TOP_K}"),
)
async def display_graph(save_to_file: bool = True):
import matplotlib.pyplot as plt
import networkx as nx
graph = nx.Graph()
rels: list[Relation] = await vr.select_by(Relation.partial_init())
ent_table: dict[UUID, Entity] = {}
for rel in rels:
for uuid in (rel.source, rel.target):
if uuid not in ent_table:
ent = await vr.select_by(Entity.partial_init(uuid=uuid))
if ent:
ent_table[uuid] = ent[0]
ents = list(ent_table.values())
edge_labels = {}
for i, ent in enumerate(ents):
graph.add_node(ent.uuid, label=ent.label, text=ent.text, index=i)
for rel in rels:
graph.add_edge(rel.source, rel.target, text=rel.text)
edge_labels[(rel.source, rel.target)] = rel.text
fig = plt.figure(figsize=(12, 10))
pos = nx.spring_layout(graph, k=0.5, iterations=50, seed=42)
nx.draw_networkx_nodes(
graph,
pos,
node_size=2500,
node_color="skyblue",
alpha=0.7,
edgecolors="black",
linewidths=1,
)
nx.draw_networkx_edges(
graph, pos, edgelist=graph.edges(), width=1, alpha=0.5, edge_color="gray"
)
    # label nodes with the entity text stored above instead of the raw UUIDs
    nx.draw_networkx_labels(
        graph, pos, labels={ent.uuid: ent.text for ent in ents}, font_size=10
    )
nx.draw_networkx_edge_labels(
graph, pos, edge_labels=edge_labels, font_size=8, label_pos=0.5
)
plt.title("Entity-Relation Graph")
plt.axis("off")
plt.tight_layout()
if save_to_file:
with open("graph.png", "wb") as f:
fig.savefig(f, format="png")
return
else:
plt.show()
async def main():
save_dir = Path("datasets")
async with vr, emb, ner:
        download_dataset(DEFAULT_DATASET, save_dir)
await load_corpus(DEFAULT_DATASET, save_dir)
await load_query(DEFAULT_DATASET, save_dir)
res: list[Evaluation] = await evaluate()
print("ndcg", sum(r.ndcg for r in res) / len(res))
print("recall@10", sum(r.recall for r in res) / len(res))
await display_graph()
if __name__ == "__main__":
import asyncio
asyncio.run(main())
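`expand_by_text` above is defined but not exercised by `evaluate`. As a closing sketch (the helper name and the deduplication step are mine; everything else reuses the registry, models, and functions above), an ad-hoc query can combine plain vector search with that entity-based expansion:

async def ask(text: str) -> list[Chunk]:
    async with vr, emb, ner:
        hits = await vr.search_by_vector(
            Chunk, await emb.vectorize_query(text), topk=TOP_K
        )
        expanded = await expand_by_text(text)
        # de-duplicate by chunk uuid, keeping the vector-search order first
        return list({c.uuid: c for c in hits + expanded}.values())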