// Forked from khtsly/persistent-memory
// Project Files
// src / retrieval / engine.ts
/**
* @file retrieval/engine.ts
* Retrieval engine: combines TF-IDF similarity, memory decay, access
* frequency, and confidence into a single composite score.
*
* Inspired by the SRLM paper's insight that multiple uncertainty signals
* (self-consistency, trace length, confidence) outperform any single one.
* We analogously blend multiple retrieval signals.
*/
import { TfIdfIndex } from "./tfidf";
import { MemoryDatabase } from "../storage/db";
import { srlmRerank } from "../processing/ai";
import {
DECAY_HALF_LIFE_DAYS,
DECAY_WEIGHT,
FREQUENCY_WEIGHT,
SIMILARITY_WEIGHT,
CONFIDENCE_WEIGHT,
MIN_RELEVANCE_THRESHOLD,
MAX_SEARCH_RESULTS,
} from "../constants";
import type { MemoryRecord, ScoredMemory, RetrievalResult } from "../types";
/**
 * Compute an exponential recency score from the time since last access.
 *
 * The score is `2^(-elapsedDays / halfLifeDays)`, so a memory last accessed
 * exactly one half-life ago scores 0.5.
 *
 * @param lastAccessedAt Timestamp of the last access, in epoch milliseconds
 *   (it is compared against `Date.now()`).
 * @param halfLifeDays Number of days after which the score halves.
 * @returns A value in [0, 1] where 1 = just accessed and 0 = very old.
 */
function computeDecay(lastAccessedAt: number, halfLifeDays: number): number {
  const MS_PER_DAY = 24 * 60 * 60 * 1000;
  const elapsedDays = (Date.now() - lastAccessedAt) / MS_PER_DAY;

  // NaN / +Infinity elapsed time earns no recency credit; -Infinity is
  // "infinitely in the future" and is treated like any other future timestamp.
  if (!Number.isFinite(elapsedDays)) {
    return elapsedDays < 0 ? 1 : 0;
  }

  // A timestamp in the future (clock skew, bad import) would otherwise yield
  // a score above 1 and break the documented [0, 1] range.
  if (elapsedDays <= 0) return 1;

  // A zero, negative, or NaN half-life cannot describe a decay curve; treat it
  // as "decays immediately" instead of returning NaN or a value above 1.
  if (!(halfLifeDays > 0)) return 0;

  // exp decay: score = 2^(-t/halfLife)
  return Math.pow(2, -elapsedDays / halfLifeDays);
}
/**
 * Normalize an access count to [0, 1] using logarithmic scaling.
 * The log scale prevents a single heavily-accessed memory from dominating.
 *
 * @param accessCount Number of times the memory has been accessed.
 * @param maxAccessCount Largest access count the caller knows about; used as
 *   the normalization ceiling.
 * @returns `log(1 + accessCount) / log(1 + maxAccessCount)`, clamped to
 *   [0, 1]. Returns 0 when either count is not a positive number.
 */
function normalizeFrequency(
  accessCount: number,
  maxAccessCount: number,
): number {
  // Written as negated `>` so NaN is rejected as well. A non-positive
  // accessCount would otherwise reach Math.log with an argument <= 0 and
  // produce NaN or -Infinity.
  if (!(maxAccessCount > 0) || !(accessCount > 0)) return 0;

  const ratio = Math.log(1 + accessCount) / Math.log(1 + maxAccessCount);

  // The caller's maximum can lag behind the real one, so accessCount may
  // exceed it; cap at 1 to keep the documented range. (The comparison also
  // maps an Infinity/Infinity NaN to 1.)
  return ratio < 1 ? ratio : 1;
}
/**
 * Ranks stored memories for a query.
 *
 * Candidates come from an in-memory TF-IDF index, with the database's
 * full-text search as a fallback when the index returns nothing. Each
 * candidate is scored by a weighted blend of similarity, recency decay,
 * access frequency, and confidence.
 */
export class RetrievalEngine {
  /** TF-IDF candidates fetched per requested result, to leave room for re-scoring. */
  private static readonly CANDIDATE_MULTIPLIER = 3;
  /** Hard cap on TF-IDF candidates per query. */
  private static readonly MAX_CANDIDATES = 100;
  /** Similarity assigned to every full-text fallback hit (no TF-IDF score exists for them). */
  private static readonly FTS_FLAT_SIMILARITY = 0.5;
  /** Phase-1 over-fetch factor for SRLM retrieval. */
  private static readonly SRLM_OVERFETCH_MULTIPLIER = 2;
  /** Hard cap on phase-1 results for SRLM retrieval. */
  private static readonly SRLM_MAX_BASE_RESULTS = 30;
  /** Number of top phase-1 results handed to the re-ranker. */
  private static readonly SRLM_MAX_CANDIDATES = 12;
  /** Blend weight of the phase-1 composite score. */
  private static readonly SRLM_COMPOSITE_WEIGHT = 0.6;
  /** Blend weight of the SRLM score. */
  private static readonly SRLM_SCORE_WEIGHT = 0.4;
  /** Multiplier applied to memories the re-ranker returned no score for. */
  private static readonly SRLM_UNSCORED_PENALTY = 0.5;

  private readonly tfIdf: TfIdfIndex;
  private readonly db: MemoryDatabase;
  /** Largest access count seen so far; ceiling for frequency normalization. */
  private maxAccessCount = 1;

  constructor(db: MemoryDatabase) {
    this.db = db;
    this.tfIdf = new TfIdfIndex();
  }

  /**
   * Build/rebuild the TF-IDF index from all stored memories.
   * Called once at startup, then incrementally maintained.
   *
   * Also resets and recomputes the cached maximum access count.
   */
  rebuildIndex(): void {
    this.tfIdf.clear();
    const allMemories = this.db.getAll(10_000);
    this.maxAccessCount = 1;
    for (const mem of allMemories) {
      this.tfIdf.addDocument(
        mem.id,
        RetrievalEngine.indexText(mem.content, mem.tags, mem.category),
      );
      if (mem.accessCount > this.maxAccessCount) {
        this.maxAccessCount = mem.accessCount;
      }
    }
  }

  /** Add a single memory to the index (incremental). */
  indexMemory(
    id: string,
    content: string,
    tags: string[],
    category: string,
  ): void {
    this.tfIdf.addDocument(
      id,
      RetrievalEngine.indexText(content, tags, category),
    );
  }

  /** Remove a memory from the index. */
  removeFromIndex(id: string): void {
    this.tfIdf.removeDocument(id);
  }

  /**
   * Retrieve memories ranked by composite score.
   * Used by both the prompt preprocessor and the explicit tools.
   *
   * @param query Free-text query.
   * @param limit Maximum number of memories to return. Non-positive or NaN
   *   values yield an empty result.
   * @param halfLifeDays Half-life used for the recency decay component.
   * @param touchAccess If false, skip updating access counters.
   *   The preprocessor sets this to false to prevent auto-inject
   *   from artificially inflating access counts.
   * @returns The ranked memories plus match count, query terms, and timing.
   */
  retrieve(
    query: string,
    limit: number = MAX_SEARCH_RESULTS,
    halfLifeDays: number = DECAY_HALF_LIFE_DAYS,
    touchAccess: boolean = true,
  ): RetrievalResult {
    const start = performance.now();

    // A negative limit would reach Array.prototype.slice as a negative end
    // index, which drops items from the tail instead of returning none.
    const maxResults = RetrievalEngine.sanitizeLimit(limit);
    if (maxResults === 0) {
      return RetrievalEngine.emptyResult(start);
    }

    const candidateLimit = Math.min(
      maxResults * RetrievalEngine.CANDIDATE_MULTIPLIER,
      RetrievalEngine.MAX_CANDIDATES,
    );
    const tfIdfResults = this.tfIdf.search(query, candidateLimit);

    if (tfIdfResults.length === 0) {
      // Nothing in the TF-IDF index: fall back to the database's full-text
      // search and give every hit the same flat similarity.
      const ftsResults = this.db.ftsSearch(query, maxResults);
      if (ftsResults.length === 0) {
        return RetrievalEngine.emptyResult(start);
      }
      return this.scoreAndRank(
        ftsResults,
        RetrievalEngine.FTS_FLAT_SIMILARITY,
        maxResults,
        halfLifeDays,
        start,
        query,
        undefined,
        touchAccess,
      );
    }

    const similarityMap = new Map<string, number>();
    for (const [docId, score] of tfIdfResults) {
      similarityMap.set(docId, score);
    }
    const memories = this.db.getByIds(tfIdfResults.map(([id]) => id));

    return this.scoreAndRank(
      memories,
      null,
      maxResults,
      halfLifeDays,
      start,
      query,
      similarityMap,
      touchAccess,
    );
  }

  /**
   * Score each candidate, drop those below the relevance threshold, sort by
   * composite score (descending), and keep the top `limit`.
   *
   * @param memories Candidate records to score.
   * @param flatSimilarity When non-null, used as the similarity of every
   *   candidate; otherwise similarity is looked up in `similarityMap`
   *   (defaulting to 0 when absent).
   * @param limit Maximum number of memories to return.
   * @param halfLifeDays Half-life used for the recency decay component.
   * @param startTime `performance.now()` reading taken when retrieval began.
   * @param query Original query; split into lower-cased terms for the result.
   * @param similarityMap Per-memory similarity scores keyed by memory id.
   * @param touchAccess Whether to update access counters for the returned memories.
   */
  private scoreAndRank(
    memories: MemoryRecord[],
    flatSimilarity: number | null,
    limit: number,
    halfLifeDays: number,
    startTime: number,
    query: string,
    similarityMap?: Map<string, number>,
    touchAccess: boolean = true,
  ): RetrievalResult {
    // Refresh the normalization ceiling here so that every retrieval path
    // (TF-IDF and full-text fallback) sees an up-to-date maximum.
    for (const mem of memories) {
      if (mem.accessCount > this.maxAccessCount) {
        this.maxAccessCount = mem.accessCount;
      }
    }

    const scored: ScoredMemory[] = [];
    for (const mem of memories) {
      const similarity = flatSimilarity ?? similarityMap?.get(mem.id) ?? 0;
      const decay = computeDecay(mem.lastAccessedAt, halfLifeDays);
      const frequency = normalizeFrequency(
        mem.accessCount,
        this.maxAccessCount,
      );
      const confidence = mem.confidence;

      const composite =
        SIMILARITY_WEIGHT * similarity +
        DECAY_WEIGHT * decay +
        FREQUENCY_WEIGHT * frequency +
        CONFIDENCE_WEIGHT * confidence;

      if (composite < MIN_RELEVANCE_THRESHOLD) continue;

      scored.push({
        ...mem,
        relevanceScore: similarity,
        decayScore: decay,
        compositeScore: composite,
      });
    }

    scored.sort((a, b) => b.compositeScore - a.compositeScore);
    const results = scored.slice(0, limit);

    if (touchAccess) {
      this.recordAccess(results);
    }

    const queryTerms = query
      .toLowerCase()
      .split(/\s+/)
      .filter((t) => t.length >= 2);

    return {
      memories: results,
      totalMatched: scored.length,
      queryTerms,
      timeTakenMs: performance.now() - startTime,
    };
  }

  /** Statistics reported by the underlying TF-IDF index. */
  get indexStats() {
    return this.tfIdf.stats;
  }

  /**
   * SRLM-enhanced retrieval: standard retrieval + AI re-ranking.
   *
   * Two-phase approach:
   *   Phase 1: Fast TF-IDF + composite scoring (same as retrieve()),
   *            over-fetching up to 2 × limit (capped at 30) candidates.
   *   Phase 2: SRLM re-ranking of the top 12 candidates via `srlmRerank`.
   *
   * The SRLM scores are blended with the composite scores:
   *   finalScore = 0.6 × compositeScore + 0.4 × srlmScore
   * Memories the re-ranker returned no score for have their composite score
   * halved.
   *
   * Falls back to the phase-1 ordering if re-ranking throws or returns no
   * scores.
   *
   * Access counters are updated only for the memories actually returned, not
   * for over-fetched candidates that are discarded. `timeTakenMs` covers both
   * phases.
   *
   * @param query Free-text query.
   * @param limit Maximum number of memories to return.
   * @param halfLifeDays Half-life used for the recency decay component.
   * @param K Passed through to `srlmRerank`.
   */
  async retrieveWithSRLM(
    query: string,
    limit: number = MAX_SEARCH_RESULTS,
    halfLifeDays: number = DECAY_HALF_LIFE_DAYS,
    K: number = 3,
  ): Promise<RetrievalResult> {
    const start = performance.now();
    const maxResults = RetrievalEngine.sanitizeLimit(limit);

    // Phase 1 runs with touchAccess = false: it over-fetches, and bumping the
    // counters of candidates that are later dropped would inflate them.
    const baseResult = this.retrieve(
      query,
      Math.min(
        maxResults * RetrievalEngine.SRLM_OVERFETCH_MULTIPLIER,
        RetrievalEngine.SRLM_MAX_BASE_RESULTS,
      ),
      halfLifeDays,
      false,
    );
    if (baseResult.memories.length === 0) return baseResult;

    let ranked = baseResult.memories;
    try {
      const candidates = ranked
        .slice(0, RetrievalEngine.SRLM_MAX_CANDIDATES)
        .map((m) => ({ id: m.id, content: m.content }));
      const srlmScores = await srlmRerank(query, candidates, K);

      if (srlmScores.size > 0) {
        ranked = ranked
          .map((mem) => {
            const srlmScore = srlmScores.get(mem.id);
            const compositeScore =
              srlmScore !== undefined
                ? RetrievalEngine.SRLM_COMPOSITE_WEIGHT * mem.compositeScore +
                  RetrievalEngine.SRLM_SCORE_WEIGHT * srlmScore
                : mem.compositeScore * RetrievalEngine.SRLM_UNSCORED_PENALTY;
            return { ...mem, compositeScore };
          })
          .sort((a, b) => b.compositeScore - a.compositeScore);
      }
    } catch {
      // Deliberate best-effort: if re-ranking fails, keep the phase-1 order.
    }

    const memories = ranked.slice(0, maxResults);
    this.recordAccess(memories);

    return {
      ...baseResult,
      memories,
      timeTakenMs: performance.now() - start,
    };
  }

  /** Text indexed for a memory: content, then tags, then category, space-separated. */
  private static indexText(
    content: string,
    tags: string[],
    category: string,
  ): string {
    return `${content} ${tags.join(" ")} ${category}`;
  }

  /** Map NaN and non-positive limits to 0; otherwise round down to a whole count. */
  private static sanitizeLimit(limit: number): number {
    if (Number.isNaN(limit) || limit <= 0) return 0;
    return Math.floor(limit);
  }

  /** Result returned when nothing matched. */
  private static emptyResult(startTime: number): RetrievalResult {
    return {
      memories: [],
      totalMatched: 0,
      queryTerms: [],
      timeTakenMs: performance.now() - startTime,
    };
  }

  /** Update access counters for the given memories; failures are ignored. */
  private recordAccess(memories: ReadonlyArray<ScoredMemory>): void {
    if (memories.length === 0) return;
    try {
      this.db.touchAccessBatch(memories.map((m) => m.id));
    } catch {
      // Deliberate best-effort: a failed counter update must not fail a read.
    }
  }
}