Project Files
src/retrieval/bm25.py
"""
retrieval/bm25.py — Production BM25+ sparse index (zero dependencies).
BM25+ is an improvement over classic BM25 that adds a lower-bound δ on
term frequency to avoid giving zero score to rare-but-present terms.
Reference: Lv & Zhai, "Lower-Bounding Term Frequency Normalization" (2011).
"""
from __future__ import annotations
import math
import re
from collections import defaultdict
from dataclasses import dataclass, field
@dataclass
class BM25PlusIndex:
    """
    In-memory BM25+ sparse index (Lv & Zhai, CIKM 2011).

    BM25+ augments classic BM25 with a lower bound ``delta`` on the
    normalized term frequency so a term that merely *occurs* in a long
    document still contributes a non-zero score. The bound applies only
    to terms actually present in a document — absent terms contribute
    nothing, so non-matching documents score 0 and are filtered out.

    Usage:
        index = BM25PlusIndex()
        index.build(list_of_strings)
        results = index.score("my query", top_k=20)
    """
    k1: float = 1.5     # term-frequency saturation
    b: float = 0.75     # length normalization strength
    delta: float = 0.5  # BM25+ lower bound on TF contribution
    # Internal state
    _corpus: list[list[str]] = field(default_factory=list, repr=False)
    # Per-document term-frequency tables, precomputed in build() so that
    # score() does not redo O(total tokens) counting on every query.
    _doc_tfs: list[dict[str, int]] = field(default_factory=list, repr=False)
    _df: dict[str, int] = field(default_factory=lambda: defaultdict(int), repr=False)
    _idf: dict[str, float] = field(default_factory=dict, repr=False)
    _avgdl: float = field(default=1.0, repr=False)
    _built: bool = field(default=False, repr=False)

    # ── Public ────────────────────────────────────────────────────────────────
    def build(self, documents: list[str]) -> None:
        """Tokenize all documents, cache per-doc TF tables, and build the IDF table.

        Args:
            documents: Raw document strings; their list index becomes the
                doc_index reported by score().
        """
        self._corpus = [self._tokenize(d) for d in documents]
        self._doc_tfs = []
        df: dict[str, int] = defaultdict(int)
        total_len = 0
        for tokens in self._corpus:
            total_len += len(tokens)
            tf: dict[str, int] = defaultdict(int)
            for t in tokens:
                tf[t] += 1
            self._doc_tfs.append(dict(tf))
            # tf's keys are exactly the distinct terms of this document.
            for term in tf:
                df[term] += 1
        n = len(self._corpus)
        self._avgdl = total_len / n if n else 1.0  # avoid 0-division on empty corpus
        self._df = df
        # Always-positive IDF variant: log((N + 1) / (df + 0.5)).
        self._idf = {
            term: math.log((n + 1) / (freq + 0.5))
            for term, freq in df.items()
        }
        self._built = True

    def score(self, query: str, top_k: int) -> list[tuple[int, float]]:
        """
        Score all documents against `query`.

        Returns:
            List of (doc_index, score) sorted by score descending, len ≤ top_k.
            Documents containing none of the query terms are excluded.
        """
        if not self._built or not self._corpus or top_k <= 0:
            return []
        query_terms = self._tokenize(query)
        scores: list[float] = []
        for tokens, tf in zip(self._corpus, self._doc_tfs):
            dl = len(tokens)
            s = 0.0
            for term in query_terms:
                idf = self._idf.get(term)
                if idf is None:
                    continue  # term never seen at build time
                f = tf.get(term, 0)
                if f == 0:
                    # BM25+'s delta lower-bounds *present* terms only. Adding
                    # delta for absent terms would give every document a
                    # positive score and fill top_k with non-matches.
                    continue
                # BM25+ formula: saturated, length-normalized TF plus delta.
                tf_norm = ((self.k1 + 1) * f) / (
                    self.k1 * (1 - self.b + self.b * dl / self._avgdl) + f
                ) + self.delta
                s += idf * tf_norm
            scores.append(s)
        indexed = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
        return [(idx, sc) for idx, sc in indexed[:top_k] if sc > 0]

    @property
    def doc_count(self) -> int:
        """Number of documents currently indexed."""
        return len(self._corpus)

    # ── Internal ──────────────────────────────────────────────────────────────
    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Lowercase and split on word boundaries (``\\w+`` runs)."""
        return re.findall(r"\b\w+\b", text.lower())