# src/retrieval/reranker.py
"""
retrieval/reranker.py — Cross-encoder reranking with multiple model options.
A cross-encoder jointly encodes (query, document) pairs and outputs a
relevance score that is far more accurate than bi-encoder cosine similarity.
The cost is O(top_k) forward passes at query time — use it after the fast
retrieval stage to re-score a small candidate pool (≤ 50 chunks).
Available models (ranked by quality, slowest last)
────────────────────────────────────────────────────────────────────────────
ms-marco-MiniLM-L-6-v2     Fast     · best for latency-sensitive setups
ms-marco-MiniLM-L-12-v2    Balanced · 2× slower than L-6, noticeably better
ms-marco-electra-base      Powerful · state-of-the-art on MS MARCO
"""
from __future__ import annotations
from typing import Literal
from src.utils.logging import get_logger
# Module-level logger; CrossEncoderReranker uses it to report model loading.
log = get_logger("retrieval.reranker")
# Closed set of reranker keys accepted by CrossEncoderReranker.
# Must stay in sync with the keys of RERANKER_MODELS below.
RerankerModel = Literal[
    "ms-marco-MiniLM-L-6-v2",
    "ms-marco-MiniLM-L-12-v2",
    "ms-marco-electra-base",
]

# Registry of supported cross-encoders, keyed by reranker key.
#   hf_id:       model id handed to sentence-transformers' CrossEncoder.
#   label:       short human-readable name, logged while the model loads.
#   description: approximate latency / quality note, logged once loaded.
RERANKER_MODELS: dict[str, dict[str, str]] = {
    "ms-marco-MiniLM-L-6-v2": {
        "hf_id": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "label": "MiniLM-L6 (fast)",
        "description": "~22 ms/pair on CPU \u2014 best default.",
    },
    "ms-marco-MiniLM-L-12-v2": {
        "hf_id": "cross-encoder/ms-marco-MiniLM-L-12-v2",
        "label": "MiniLM-L12 (balanced)",
        "description": "~40 ms/pair on CPU \u2014 noticeably better recall.",
    },
    "ms-marco-electra-base": {
        "hf_id": "cross-encoder/ms-marco-electra-base",
        "label": "Electra-base (powerful)",
        "description": "~120 ms/pair on CPU \u2014 highest reranking accuracy.",
    },
}

# Key used when CrossEncoderReranker is constructed without a model_key.
DEFAULT_RERANKER: RerankerModel = "ms-marco-MiniLM-L-6-v2"
class CrossEncoderReranker:
    """
    Wraps sentence-transformers CrossEncoder for plug-and-play reranking.

    The cross-encoder selected by ``model_key`` (a key of ``RERANKER_MODELS``)
    is loaded once in ``__init__`` and reused by every ``rerank`` call.
    """

    def __init__(self, model_key: str = DEFAULT_RERANKER) -> None:
        """
        Load the cross-encoder registered under *model_key*.

        Args:
            model_key: A key of ``RERANKER_MODELS``.

        Raises:
            ImportError: sentence-transformers is not installed.
            ValueError: *model_key* is not a registered reranker.
        """
        # Imported here rather than at module top so that importing this
        # module does not itself require sentence-transformers.
        try:
            from sentence_transformers import CrossEncoder
        except ImportError as exc:
            # Chain explicitly so the original import failure stays visible.
            raise ImportError(
                "sentence-transformers is required for reranking. "
                "Install it with: pip install sentence-transformers"
            ) from exc

        cfg = RERANKER_MODELS.get(model_key)
        if cfg is None:
            available = ", ".join(RERANKER_MODELS)
            raise ValueError(
                f"Unknown reranker '{model_key}'. Available: {available}"
            )

        log.info("Loading reranker: %s", cfg["label"])
        self.model = CrossEncoder(cfg["hf_id"])
        self.model_key = model_key
        log.info("Reranker ready: %s", cfg["description"])

    def rerank(
        self,
        query: str,
        candidates: list[dict],
        text_key: str = "text",
    ) -> list[dict]:
        """
        Score each candidate against the query and sort descending.

        The input list is modified in place: every dict gains a float
        ``rerank_score`` entry and the list itself is reordered. The same
        list object is returned for convenience.

        Args:
            query: The user query string.
            candidates: List of dicts; each must have a `text_key` field.
            text_key: Key in each dict that holds the passage text.

        Returns:
            Same list, with `rerank_score` added, sorted by rerank_score DESC.
            An empty list is returned unchanged without calling the model.

        Raises:
            KeyError: a candidate has no *text_key* entry.
        """
        if not candidates:
            return candidates

        # One (query, passage) pair per candidate, scored in a single call.
        pairs = [[query, c[text_key]] for c in candidates]
        scores = self.model.predict(pairs).tolist()
        for c, s in zip(candidates, scores):
            c["rerank_score"] = float(s)

        # list.sort is stable (also with reverse=True), so candidates with
        # equal scores keep their incoming order.
        candidates.sort(key=lambda x: x["rerank_score"], reverse=True)
        return candidates