Project Files
src/ingestion/chunker.py
"""
ingestion/chunker.py β State-of-the-art semantic chunking strategies.
Strategy When to use
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
paragraph Default. Accumulates paragraphs up to chunk_size,
keeps overlap. Respects natural document structure.
sentence Splits on sentence boundaries (needs nltk punkt).
Best for QA over dense prose.
semantic Clusters sentences by embedding similarity, splits
at topic-shift boundaries (needs sentence-transformers).
Best accuracy, slowest indexing.
fixed Pure word-count. Fast, dumb, last resort.
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
"""
from __future__ import annotations
import re
from typing import Literal
from src.utils.logging import get_logger
log = get_logger("ingestion.chunker")
ChunkStrategy = Literal["paragraph", "sentence", "semantic", "fixed"]
# ── Optional imports ──────────────────────────────────────────────────────────
try:
import nltk
# Ensure punkt tokenizer is available
try:
nltk.data.find("tokenizers/punkt")
except LookupError:
nltk.download("punkt", quiet=True)
try:
nltk.data.find("tokenizers/punkt_tab")
except LookupError:
try:
nltk.download("punkt_tab", quiet=True)
except Exception:
pass
_NLTK = True
except ImportError:
_NLTK = False
try:
import numpy as np
from sentence_transformers import SentenceTransformer
_ST_FOR_SEMANTIC = True
except ImportError:
_ST_FOR_SEMANTIC = False
# ─────────────────────────────────────────────────────────────────────────────
# Public API
# ─────────────────────────────────────────────────────────────────────────────
def chunk_text(
text: str,
strategy: ChunkStrategy = "paragraph",
chunk_size: int = 512,
chunk_overlap: int = 64,
semantic_model: str = "sentence-transformers/all-MiniLM-L6-v2",
semantic_threshold: float = 0.35,
) -> list[str]:
"""
Split `text` into chunks using the chosen strategy.
Args:
text: Raw document text.
strategy: One of paragraph | sentence | semantic | fixed.
chunk_size: Target word count per chunk.
chunk_overlap: Words of context to carry into the next chunk.
semantic_model: Model used ONLY for the 'semantic' strategy.
semantic_threshold: Cosine-distance threshold for topic-shift detection.
Returns:
List of non-empty chunk strings.
"""
if not text.strip():
return []
if strategy == "paragraph":
return _paragraph_chunks(text, chunk_size, chunk_overlap)
if strategy == "sentence":
return _sentence_chunks(text, chunk_size, chunk_overlap)
if strategy == "semantic":
return _semantic_chunks(text, chunk_size, chunk_overlap,
semantic_model, semantic_threshold)
# "fixed" fallback
return _fixed_chunks(text, chunk_size, chunk_overlap)
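# Illustrative usage (made-up two-paragraph document; values are not from a
# real corpus). Both paragraphs fit well within chunk_size, so the paragraph
# strategy merges them into a single chunk:
#
#     >>> doc = "First paragraph about topic A.\n\nSecond paragraph about topic B."
#     >>> chunk_text(doc, strategy="paragraph", chunk_size=128, chunk_overlap=16)
#     ['First paragraph about topic A. Second paragraph about topic B.']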
# ─────────────────────────────────────────────────────────────────────────────
# Strategy implementations
# ─────────────────────────────────────────────────────────────────────────────
def _paragraph_chunks(text: str, chunk_size: int, overlap: int) -> list[str]:
"""
Accumulate paragraphs (blank-line separated) until chunk_size words.
Very long single paragraphs are sub-chunked with fixed strategy.
"""
paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
chunks: list[str] = []
buffer: list[str] = [] # words
def flush() -> None:
joined = " ".join(buffer).strip()
if joined:
chunks.append(joined)
for para in paragraphs:
words = para.split()
        if len(words) > chunk_size:
            # Flush the accumulated buffer before handling the oversized paragraph.
            if buffer:
                flush()
            # Sub-chunk the long paragraph with the fixed strategy.
            sub = _fixed_chunks(para, chunk_size, overlap)
            chunks.extend(sub[:-1])
            # Carry the whole last sub-chunk forward as the new buffer so its
            # words are emitted with the following text rather than dropped.
            buffer = sub[-1].split() if sub else []
            continue
if len(buffer) + len(words) > chunk_size:
flush()
buffer = buffer[-overlap:] + words
else:
buffer.extend(words)
if buffer:
flush()
return chunks
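# Worked illustration of the accumulation path (hypothetical sizes): with
# chunk_size=100 and chunk_overlap=10, paragraphs of 60 and 70 words produce
# two chunks. The first is the 60-word paragraph (adding 70 more would exceed
# 100); the second is that paragraph's 10-word tail plus the 70-word paragraph.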
def _sentence_chunks(text: str, chunk_size: int, overlap: int) -> list[str]:
"""
Split into sentences (nltk), then accumulate into word-count windows.
Falls back to paragraph chunking if nltk is unavailable.
"""
if not _NLTK:
log.warning("nltk not installed; falling back to paragraph chunking.")
return _paragraph_chunks(text, chunk_size, overlap)
sentences = nltk.sent_tokenize(text)
chunks: list[str] = []
buffer: list[str] = [] # words
def flush() -> None:
joined = " ".join(buffer).strip()
if joined:
chunks.append(joined)
for sent in sentences:
words = sent.split()
if len(buffer) + len(words) > chunk_size:
flush()
buffer = buffer[-overlap:] + words
else:
buffer.extend(words)
if buffer:
flush()
return chunks
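# Same accumulation as the paragraph strategy, but at sentence granularity
# (hypothetical sizes): with chunk_size=8 and chunk_overlap=2, three 5-word
# sentences yield chunks of 5, 7, and 7 words, each later chunk starting with
# the 2-word tail of the previous window.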
def _semantic_chunks(
text: str,
chunk_size: int,
overlap: int,
model_name: str,
threshold: float,
) -> list[str]:
"""
Semantic chunking via cosine-distance topic-shift detection.
Algorithm:
1. Split text into sentences.
2. Embed every sentence with a small SentenceTransformer.
3. Compute cosine distance between consecutive sentence embeddings.
4. Where distance > threshold → topic shift → start a new chunk.
5. Merge resulting segments into word-count windows with overlap.
Reference: Greg Kamradt's "5 Levels of Text Splitting" (2024).
"""
if not _NLTK or not _ST_FOR_SEMANTIC:
log.warning(
"nltk or sentence-transformers unavailable; "
"falling back to paragraph chunking for semantic strategy."
)
return _paragraph_chunks(text, chunk_size, overlap)
sentences = nltk.sent_tokenize(text)
if len(sentences) < 2:
return _paragraph_chunks(text, chunk_size, overlap)
log.info("Semantic chunking: embedding %d sentencesβ¦", len(sentences))
# Use a lightweight model just for topic-shift detection
model = SentenceTransformer(model_name)
embeddings = model.encode(sentences, normalize_embeddings=True, show_progress_bar=False)
# Cosine distance = 1 - dot(a, b) for normalized vectors
distances = [
float(1.0 - np.dot(embeddings[i], embeddings[i + 1]))
for i in range(len(embeddings) - 1)
]
# Build groups of sentences separated by topic-shift boundaries
groups: list[list[str]] = []
current: list[str] = [sentences[0]]
for i, dist in enumerate(distances):
if dist > threshold:
groups.append(current)
current = []
current.append(sentences[i + 1])
if current:
groups.append(current)
# Now merge groups into word-count chunks
raw_segments = [" ".join(g) for g in groups]
merged: list[str] = []
buffer: list[str] = []
def flush() -> None:
joined = " ".join(buffer).strip()
if joined:
merged.append(joined)
for seg in raw_segments:
words = seg.split()
        if len(words) > chunk_size:
            if buffer:
                flush()
            # Sub-chunk the oversized segment and carry its tail as overlap
            # into the next chunk, mirroring the paragraph strategy.
            sub = _fixed_chunks(seg, chunk_size, overlap)
            merged.extend(sub)
            buffer = sub[-1].split()[-overlap:] if sub else []
            continue
if len(buffer) + len(words) > chunk_size:
flush()
buffer = buffer[-overlap:] + words
else:
buffer.extend(words)
if buffer:
flush()
log.info("Semantic chunking: %d chunks produced.", len(merged))
return merged
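# Boundary-detection illustration with hypothetical distances and the default
# threshold of 0.35:
#
#     sentences: [s0, s1, s2, s3]
#     distances: [0.10, 0.52, 0.08]   # only the s1→s2 distance exceeds 0.35
#     groups:    [[s0, s1], [s2, s3]]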
def _fixed_chunks(text: str, chunk_size: int, overlap: int) -> list[str]:
"""Pure word-count chunking with overlap. Fastest, no structure awareness."""
words = text.split()
chunks: list[str] = []
i = 0
step = max(1, chunk_size - overlap)
while i < len(words):
chunks.append(" ".join(words[i: i + chunk_size]))
i += step
return [c for c in chunks if c.strip()]
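

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the public API. Assumes the
    # package is importable from the repository root, e.g.
    # `python -m src.ingestion.chunker`. It exercises only the dependency-free
    # strategies with a made-up two-paragraph document.
    sample = (
        "Retrieval-augmented generation indexes documents as chunks. "
        "Chunk boundaries strongly affect answer quality.\n\n"
        "Overlap carries context across boundaries so that answers spanning "
        "two chunks remain recoverable."
    )
    for strat in ("paragraph", "fixed"):
        out = chunk_text(sample, strategy=strat, chunk_size=16, chunk_overlap=4)
        print(f"{strat}: {len(out)} chunk(s)")
        for c in out:
            print("   ", c)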