src/retrieval/engine.py
"""
retrieval/engine.py — The unified RAG engine.
Orchestrates the full retrieval pipeline:
Ingest → Chunk → Embed → Store (ChromaDB)
Query → HyDE? → Dense + Sparse → RRF → Rerank → Return
This is the single class server.py talks to.
"""
from __future__ import annotations
import hashlib
from pathlib import Path
import chromadb
from chromadb.utils import embedding_functions
from config import (
CANDIDATE_MULTIPLIER,
DEFAULT_CHUNK_OVERLAP,
DEFAULT_CHUNK_SIZE,
DEFAULT_COLLECTION,
DEFAULT_EMBEDDING_MODEL,
DEFAULT_MIN_SCORE,
DEFAULT_TOP_K,
EMBEDDING_MODELS,
RERANKER_MODEL,
)
from src.ingestion.chunker import ChunkStrategy, chunk_text
from src.ingestion.extractors import SUPPORTED_EXTENSIONS, extract_text
from src.retrieval.bm25 import BM25PlusIndex
from src.retrieval.fusion import reciprocal_rank_fusion
from src.retrieval.reranker import CrossEncoderReranker
from src.utils.logging import get_logger
log = get_logger("retrieval.engine")
def _resolve_model(key: str) -> dict:
if key not in EMBEDDING_MODELS:
available = ", ".join(EMBEDDING_MODELS)
raise ValueError(f"Unknown embedding model '{key}'. Available: {available}")
return EMBEDDING_MODELS[key]
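# Each EMBEDDING_MODELS entry must provide at least the keys this module
# reads (the authoritative schema lives in config.py):
#   {"label": ..., "tier": ..., "model_name": ..., "prefix_q": ..., "prefix_d": ...}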
class RAGEngine:
"""
Full RAG pipeline: ingest documents → hybrid retrieval → optional reranking.
Parameters
──────────
docs_path : Folder containing documents to index.
collection_name : ChromaDB collection prefix.
embedding_model : Key from config.EMBEDDING_MODELS.
chunk_strategy : paragraph | sentence | semantic | fixed.
chunk_size : Target words per chunk.
chunk_overlap : Word overlap between consecutive chunks.
top_k : Default number of results to return.
    min_score : Minimum cosine similarity for a dense hit to be included
                (BM25-only hits are exempt; see search()).
use_hybrid : If True, combine BM25+ with dense retrieval via RRF.
use_reranker : If True, apply cross-encoder reranking.
    reranker_model : Model key passed to CrossEncoderReranker
                     (default: config.RERANKER_MODEL).
use_hyde : If True, call HyDE before dense retrieval.
"""
def __init__(
self,
docs_path: str,
collection_name: str = DEFAULT_COLLECTION,
embedding_model: str = DEFAULT_EMBEDDING_MODEL,
chunk_strategy: ChunkStrategy = "paragraph",
chunk_size: int = DEFAULT_CHUNK_SIZE,
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
top_k: int = DEFAULT_TOP_K,
min_score: float = DEFAULT_MIN_SCORE,
use_hybrid: bool = True,
use_reranker: bool = False,
reranker_model: str = RERANKER_MODEL,
use_hyde: bool = False,
) -> None:
self.docs_path = Path(docs_path)
self.chunk_strategy = chunk_strategy
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.top_k = top_k
self.min_score = min_score
self.use_hybrid = use_hybrid
self.use_hyde = use_hyde
# ── Embedding model ───────────────────────────────────────────────────
self.model_key = embedding_model
self.model_cfg = _resolve_model(embedding_model)
log.info("Embedding : %s", self.model_cfg["label"])
log.info("Tier : %s", self.model_cfg["tier"])
self._embed_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=self.model_cfg["model_name"]
)
# ── ChromaDB ──────────────────────────────────────────────────────────
db_path = str(self.docs_path.parent / ".chroma_db")
self._client = chromadb.PersistentClient(path=db_path)
# Encode model key into collection name so swapping models never
# accidentally uses stale embeddings from a previous model.
safe_key = embedding_model.replace("-", "_").replace(".", "_")
coll_name = f"{collection_name}__{safe_key}"
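        # e.g. collection "docs" + model key "bge-small-en-v1.5" (hypothetical)
        #      → "docs__bge_small_en_v1_5"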
self._collection = self._client.get_or_create_collection(
name=coll_name,
embedding_function=self._embed_fn,
metadata={"hnsw:space": "cosine"},
)
log.info("ChromaDB : %s (collection: %s, %d chunks)",
db_path, coll_name, self._collection.count())
# ── BM25+ sparse index ────────────────────────────────────────────────
self._bm25: BM25PlusIndex = BM25PlusIndex()
self._bm25_ids: list[str] = [] # parallel to BM25 corpus
self._bm25_docs: list[str] = []
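        # Populated by ingest() → _rebuild_bm25(); until then search() quietly
        # falls back to dense-only retrieval (see the doc_count guard there).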
# ── Cross-encoder reranker ────────────────────────────────────────────
self._reranker: CrossEncoderReranker | None = None
if use_reranker:
try:
self._reranker = CrossEncoderReranker(reranker_model)
except ImportError as exc:
log.warning("Reranker disabled: %s", exc)
# ──────────────────────────────────────────────────────────────────────────
# Ingestion
# ──────────────────────────────────────────────────────────────────────────
def ingest(self) -> dict[str, int]:
"""
Scan docs_path, index new/changed files, skip already-indexed ones.
Returns:
{"new": N, "skipped": M, "total_chunks": K}
"""
files = [
f for f in self.docs_path.rglob("*")
if f.is_file() and f.suffix.lower() in SUPPORTED_EXTENSIONS
]
log.info("Found %d file(s) to check…", len(files))
new_count = skipped = 0
for file in files:
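            # Content-addressed dedup: if any stored chunk carries this file's
            # hash, the file is unchanged since the last ingest.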
file_hash = _file_hash(file)
existing = self._collection.get(where={"file_hash": file_hash}, limit=1)
if existing["ids"]:
skipped += 1
continue
raw = extract_text(file)
if not raw.strip():
log.warning("Empty extraction: %s", file.name)
continue
chunks = chunk_text(
raw,
strategy=self.chunk_strategy,
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
)
if not chunks:
continue
ids, docs, metas = [], [], []
for i, chunk in enumerate(chunks):
# Apply model-specific document prefix at index time
prefixed = self.model_cfg["prefix_d"] + chunk
chunk_id = f"{file_hash}_{i}"
ids.append(chunk_id)
docs.append(prefixed)
metas.append({
"source": str(file.relative_to(self.docs_path)),
"file_hash": file_hash,
"chunk_index": i,
"raw_text": chunk, # unprefixed, for BM25 + display
"file_name": file.name,
"file_suffix": file.suffix.lower(),
})
            # Drop stale chunks left over from a previous version of this
            # file (same relative path, different content hash) before adding.
            self._collection.delete(
                where={"source": str(file.relative_to(self.docs_path))}
            )
            self._collection.add(ids=ids, documents=docs, metadatas=metas)
new_count += 1
log.info("Indexed: %s (%d chunks)", file.name, len(chunks))
total = self._collection.count()
log.info(
"Ingest complete — new: %d skipped: %d total chunks: %d",
new_count, skipped, total,
)
if self.use_hybrid:
self._rebuild_bm25()
return {"new": new_count, "skipped": skipped, "total_chunks": total}
def _rebuild_bm25(self) -> None:
"""Reload all chunks from ChromaDB and rebuild the BM25+ index."""
data = self._collection.get(include=["metadatas"])
if not data or not data["ids"]:
return
self._bm25_ids = data["ids"]
# BM25 uses raw (unprefixed) text for keyword matching
self._bm25_docs = [m.get("raw_text", "") for m in data["metadatas"]]
self._bm25.build(self._bm25_docs)
log.info("BM25+ index built over %d chunks.", len(self._bm25_docs))
# ──────────────────────────────────────────────────────────────────────────
# Search
# ──────────────────────────────────────────────────────────────────────────
async def search(
self,
query: str,
top_k: int | None = None,
min_score: float | None = None,
) -> list[dict]:
"""
Run the full retrieval pipeline for a query.
Pipeline:
[HyDE] optional hypothesis generation
↓
[Dense] ChromaDB cosine similarity (CANDIDATE_MULTIPLIER × top_k)
↓
[BM25+] sparse keyword search (if use_hybrid)
↓
[RRF] Reciprocal Rank Fusion
↓
[Rerank] Cross-encoder reranking (if use_reranker)
↓
[Return] top_k results
Returns:
List of result dicts with keys:
text, source, file_name, score, rrf_score, rerank_score (opt)
"""
total = self._collection.count()
if total == 0:
return []
eff_top_k = max(1, top_k if top_k is not None else self.top_k)
eff_min_score = min_score if min_score is not None else self.min_score
candidate_k = min(eff_top_k * CANDIDATE_MULTIPLIER, total)
# ── HyDE ──────────────────────────────────────────────────────────────
retrieval_query = query
if self.use_hyde:
from src.retrieval.hyde import generate_hypothesis
hypothesis = await generate_hypothesis(query)
if hypothesis:
retrieval_query = hypothesis
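        # Only dense retrieval sees the hypothesis; BM25 below still matches
        # against the user's original keywords.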
# Apply model-specific query prefix
prefixed_query = self.model_cfg["prefix_q"] + retrieval_query
# ── Dense retrieval ───────────────────────────────────────────────────
dense_raw = self._collection.query(
query_texts=[prefixed_query],
n_results=candidate_k,
include=["documents", "metadatas", "distances"],
)
id_data: dict[str, dict] = {}
dense_ranked: list[tuple[str, float]] = []
for doc, meta, dist, doc_id in zip(
dense_raw["documents"][0],
dense_raw["metadatas"][0],
dense_raw["distances"][0],
dense_raw["ids"][0],
):
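            # This collection uses "hnsw:space": "cosine", so Chroma returns
            # cosine *distance*; convert back to similarity.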
score = 1.0 - float(dist)
id_data[doc_id] = {
"text": meta.get("raw_text", doc),
"source": meta.get("source", "unknown"),
"file_name": meta.get("file_name", ""),
"score": score,
}
dense_ranked.append((doc_id, score))
# ── BM25+ sparse retrieval ────────────────────────────────────────────
if self.use_hybrid and self._bm25.doc_count > 0:
bm25_raw = self._bm25.score(query, top_k=candidate_k)
bm25_ranked: list[tuple[str, float]] = []
for idx, bm25_score in bm25_raw:
if idx >= len(self._bm25_ids):
continue
did = self._bm25_ids[idx]
bm25_ranked.append((did, bm25_score))
if did not in id_data:
# Pull from Chroma if not already in dense results
fetched = self._collection.get(
ids=[did], include=["documents", "metadatas"]
)
if fetched["ids"]:
m = fetched["metadatas"][0]
id_data[did] = {
"text": m.get("raw_text", fetched["documents"][0]),
"source": m.get("source", "unknown"),
"file_name": m.get("file_name", ""),
"score": 0.0,
}
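            # RRF merges both rankings by reciprocal rank, roughly
            #   fused(d) = sum over lists of 1 / (k + rank(d)),
            # with k ≈ 60 by convention (the actual constant lives in
            # src/retrieval/fusion.py).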
fused = reciprocal_rank_fusion([dense_ranked, bm25_ranked])
else:
            fused = list(dense_ranked)
        # ── Score filter & candidate assembly ────────────────────────────────
        # The cosine floor applies only to dense hits; BM25-only candidates
        # carry a placeholder score of 0.0 and would otherwise always be
        # dropped whenever min_score > 0.
        dense_ids = {did for did, _ in dense_ranked}
        candidates: list[dict] = []
        for doc_id, rrf_score in fused:
            if doc_id not in id_data:
                continue
            data = id_data[doc_id]
            if doc_id in dense_ids and data["score"] < eff_min_score:
                continue
            candidates.append({
                **data,
                "rrf_score": rrf_score,
            })
        # Cap the pool fed into the reranker: every candidate costs a
        # cross-encoder forward pass, so 3 × top_k keeps latency bounded.
candidates = candidates[: eff_top_k * 3]
# ── Cross-encoder reranking ───────────────────────────────────────────
if self._reranker is not None and candidates:
candidates = self._reranker.rerank(query, candidates)
else:
candidates.sort(key=lambda x: x["rrf_score"], reverse=True)
return candidates[:eff_top_k]
# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
def _file_hash(file: Path) -> str:
    """Content hash for change detection (MD5 here is not security-sensitive)."""
    digest = hashlib.md5(usedforsecurity=False)
    with file.open("rb") as fh:
        # Stream in 1 MiB blocks so large documents never load fully into RAM.
        for block in iter(lambda: fh.read(1 << 20), b""):
            digest.update(block)
    return digest.hexdigest()
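
if __name__ == "__main__":  # pragma: no cover
    # Minimal smoke test, a sketch only: the docs path and query are
    # illustrative, and config defaults are assumed to resolve locally.
    import asyncio

    engine = RAGEngine(docs_path="./docs")
    print(engine.ingest())
    for hit in asyncio.run(engine.search("example query", top_k=3)):
        print(f"{hit['score']:.3f}  {hit['source']}  {hit['text'][:60]!r}")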