// Forked from: src/processing/ai.ts
/**
* @file processing/ai.ts
* AI-powered memory processing using the loaded LM Studio model.
*
* Implements SRLM uncertainty-aware self-reflection for memory retrieval,
* faithfully adapted from arXiv:2603.15653.
*
* FROM THE PAPER (§2.2):
* SRLM selects among K candidate context-interaction programs using three
* complementary uncertainty signals:
* 1. Self-Consistency — generate K candidates; keep the CONSISTENT SET S
* (those whose output matches the plurality answer). This is a filter,
* not a weight — candidates outside S are discarded before Stage 2.
* 2. Verbalized Confidence (VC) — model reports confidence per step,
* aggregated in log-space: VC(p) = Σ_t log(ν_t / 100) ≤ 0
* 3. Trace Length (Len) — total tokens across all steps; shorter = more decisive
* Joint score: s(p) = VC(p) · Len(p), select p* = argmax_{p∈S} s(p)
* Since both VC ≤ 0 and Len > 0, argmax picks the least-negative product.
*
* ADAPTATION FOR MEMORY RETRIEVAL:
* The paper selects a single best program; we re-rank a list of memories.
* The unit changes but the three-stage algorithm is preserved exactly:
*
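* 1. Self-Consistency (filter): run K assessments; build the consistent set
* of memories selected by a strict majority (> n/2) of the n assessments
* that returned a parseable response (n ≤ K). If no memory reaches a
* majority, the set falls back to every memory selected at least once.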
* 2. Verbalized Confidence: for each memory in the consistent set, compute
* avgLogVC = mean of log(ν_k / 100) across assessments that selected it.
* 3. Trace Length: compute avgTokenLen = mean token count of those assessments.
* Joint: s(m) = avgLogVC × avgTokenLen (≤ 0); argmax selects best memory.
* Normalize [s_min, s_max] → [0, 1] for downstream blending.
*
* Also provides: fact extraction, conflict detection.
* All calls are optional (graceful fallback if no model loaded).
* All calls have strict timeouts to never block the chat flow.
*/
import { LMStudioClient } from "@lmstudio/sdk";
import {
AI_EXTRACT_MAX_TOKENS,
AI_EXTRACT_TEMPERATURE,
AI_CALL_TIMEOUT_MS,
AI_CONFLICT_MAX_TOKENS,
AI_CONFLICT_TEMPERATURE,
AI_RELEVANCE_MAX_TOKENS,
AI_RELEVANCE_TEMPERATURE,
VALID_CATEGORIES,
type MemoryCategory,
} from "../constants";
import type { MemoryConflict, MemoryRecord } from "../types";
/**
 * Shared LM Studio client. Created lazily by `getClient`; callers set it
 * back to `null` after a failed call so the next call builds a fresh one.
 */
let cachedClient: LMStudioClient | null = null;
/**
 * Return the shared LM Studio client, constructing it on first use
 * (or on the first use after the cache was cleared).
 */
function getClient(): LMStudioClient {
  if (cachedClient === null) {
    cachedClient = new LMStudioClient();
  }
  return cachedClient;
}
/**
 * Run a single-turn prompt against the first loaded LM Studio model.
 *
 * The WHOLE call is bounded by one `timeoutMs` deadline: listing the
 * loaded models, resolving the model handle, and streaming the response.
 * Previously only `listLoaded()` was raced, so a slow generation could
 * block the caller indefinitely. If the deadline passes first, the
 * in-flight prediction is cancelled (best effort) and `null` is returned.
 * The deadline timer is always cleared, so no timer outlives the call.
 *
 * @param prompt - User message sent to the model.
 * @param maxTokens - Generation cap forwarded to `respond`.
 * @param temperature - Sampling temperature forwarded to `respond`.
 * @param timeoutMs - Overall deadline for the call, in milliseconds.
 * @returns The trimmed response text and its character count. Returns
 *   `null` if no model is loaded, the response is empty, the deadline
 *   passes, or any SDK call throws. On failure the cached client is
 *   discarded so the next call reconnects.
 */
async function callModelWithMeta(
  prompt: string,
  maxTokens: number,
  temperature: number,
  timeoutMs: number = AI_CALL_TIMEOUT_MS,
): Promise<{ text: string; charCount: number } | null> {
  // Use a holder object rather than a bare `let`. `run` assigns the stream
  // asynchronously, and TS would otherwise narrow a `let` to `undefined`
  // in the catch block below.
  const inflight: { prediction?: { cancel?: () => unknown } } = {};
  let timer: ReturnType<typeof setTimeout> | undefined;
  const deadline = new Promise<never>((_, reject) => {
    timer = setTimeout(() => reject(new Error("timeout")), timeoutMs);
  });

  const run = async (): Promise<{ text: string; charCount: number } | null> => {
    const client = getClient();
    const models = await client.llm.listLoaded();
    if (!Array.isArray(models) || models.length === 0) return null;
    const model = await client.llm.model(models[0].identifier);
    const stream = model.respond([{ role: "user", content: prompt }], {
      maxTokens,
      temperature,
    });
    inflight.prediction = stream;
    let text = "";
    for await (const chunk of stream) text += chunk.content ?? "";
    const trimmed = text.trim();
    if (!trimmed) return null;
    return { text: trimmed, charCount: trimmed.length };
  };

  try {
    // Promise.race also attaches a handler to run(), so a late rejection
    // from an abandoned run() is not reported as unhandled.
    return await Promise.race([run(), deadline]);
  } catch {
    // Stop a generation that outlived the deadline so it does not keep the
    // model busy. A failure to cancel is irrelevant to the caller.
    try {
      void Promise.resolve(inflight.prediction?.cancel?.()).catch(
        () => undefined,
      );
    } catch {
      /* best-effort cancellation */
    }
    cachedClient = null;
    return null;
  } finally {
    clearTimeout(timer);
  }
}
/**
 * Text-only convenience wrapper around `callModelWithMeta`, kept for
 * backward compatibility with callers that do not need the length metadata.
 *
 * @param prompt - User message sent to the model.
 * @param maxTokens - Generation cap.
 * @param temperature - Sampling temperature.
 * @param timeoutMs - Optional timeout; when omitted, `callModelWithMeta`
 *   applies its own default.
 * @returns The trimmed response text, or `null` when the call produced nothing.
 */
async function callModel(
  prompt: string,
  maxTokens: number,
  temperature: number,
  timeoutMs?: number,
): Promise<string | null> {
  const meta = await callModelWithMeta(prompt, maxTokens, temperature, timeoutMs);
  if (meta === null) {
    return null;
  }
  return meta.text;
}
/**
 * One of the K independent relevance samples used by `srlmRerank`, parsed
 * from the model's text response.
 *
 * Memories are referenced by their real IDs throughout; nothing in the
 * SRLM pipeline maps list positions back to IDs.
 */
interface CandidateAssessment {
  /** IDs of the memories this sample judged relevant, in the model's ranked order. */
  selectedIds: string[];
  /**
   * Verbalized confidence in log space.
   *
   * Paper §2.2.1 defines VC(p) = Σ_t log(ν_t / 100) over the steps of a
   * trace. Each sample here is a single-step response, so the sum has one
   * term: logVC = log(ν / 100). With ν ∈ (0, 100], logVC ≤ 0; values
   * nearer 0 mean higher confidence.
   */
  logVC: number;
  /**
   * Approximate response length in tokens.
   *
   * Paper §2.2.1 defines Len(p) = Σ_t ℓ_t. Here it is estimated as
   * characters / 4. Shorter traces are read as more decisive.
   */
  tokenLen: number;
}
/**
 * Upper bound on the per-sample temperature used by `srlmRerank`.
 * The schedule 0.3, 0.5, 0.7, … is unchanged for K ≤ 4. Larger K values
 * are clamped here instead of climbing into near-random sampling
 * (e.g. 2.1 at K = 10).
 */
const SRLM_MAX_SAMPLE_TEMPERATURE = 1.0;
/**
 * SRLM relevance re-ranking, adapted from arXiv:2603.15653.
 *
 * The paper's SRLM selects among K candidate context-interaction programs
 * using three uncertainty signals. Here the same three-stage algorithm
 * selects among candidate memories retrieved for a query. The unit of
 * selection changes (memories instead of programs), but each signal
 * follows the paper.
 *
 * STAGE 1 — Self-Consistency filter (§2.2.1)
 *   - Run K independent relevance assessments. For each memory, the
 *     "answer" is whether it was selected.
 *   - The consistent set S holds memories selected by a strict majority
 *     (> n/2) of the n assessments that produced a usable response.
 *   - If no memory reaches a majority (e.g. two assessments that disagree
 *     completely), S falls back to every memory selected at least once.
 *   - An ID listed twice in one response counts once.
 *   - Memories outside S are discarded before fine-grained scoring:
 *     self-consistency is a coarse filter, not a weight.
 *
 * STAGE 2 — Fine-grained scoring within S (§2.2.2)
 *   For each memory m ∈ S, aggregate over the assessments that selected it:
 *     avgLogVC(m)    = mean of log(ν_k / 100) over those assessments  [≤ 0]
 *     avgTokenLen(m) = mean token count of those assessments          [> 0]
 *
 * STAGE 3 — Joint score and selection (§2.2.2)
 *   s(m) = avgLogVC(m) × avgTokenLen(m)   [≤ 0, since VC ≤ 0 and Len > 0]
 *   argmax s(m) picks the least-negative product, i.e. the highest
 *   confidence combined with the shortest (most decisive) trace.
 *
 * Finally, scores are normalized from [s_min, s_max] to [0, 1] for
 * downstream blending. When every memory scores the same, each gets 1.0.
 *
 * @param context - Current user query / conversation context (first 600 chars used)
 * @param memories - Candidate memories to re-rank (only the first 12 are considered)
 * @param K - Number of parallel candidate samples (default 3)
 * @returns Map of memoryId → normalized SRLM score in [0, 1]. It contains
 *   only memories in the consistent set, and is empty when `memories` is
 *   empty, no assessment succeeded, or no assessment selected anything.
 */
export async function srlmRerank(
  context: string,
  memories: Array<{ id: string; content: string }>,
  K: number = 3,
): Promise<Map<string, number>> {
  if (memories.length === 0) return new Map<string, number>();
  const capped = memories.slice(0, 12);
  const prompt = buildSrlmPrompt(context, capped);

  // Paper §2.2: "K candidate programs are independently sampled from πθ".
  // Temperature rises with k for sampling diversity, capped for large K.
  const assessmentPromises = Array.from({ length: K }, (_, k) =>
    callModelWithMeta(
      prompt,
      AI_RELEVANCE_MAX_TOKENS,
      Math.min(0.3 + k * 0.2, SRLM_MAX_SAMPLE_TEMPERATURE),
      AI_CALL_TIMEOUT_MS,
    ).then((result): CandidateAssessment | null =>
      result
        ? parseCandidateAssessment(result.text, result.charCount, capped)
        : null,
    ),
  );

  const assessments: CandidateAssessment[] = [];
  for (const outcome of await Promise.allSettled(assessmentPromises)) {
    if (outcome.status !== "fulfilled" || outcome.value === null) continue;
    const { selectedIds, logVC, tokenLen } = outcome.value;
    // A NaN/∞ signal would poison every joint score it touches and the
    // min/max normalization, so drop such samples entirely.
    if (!Number.isFinite(logVC) || !Number.isFinite(tokenLen)) continue;
    // De-duplicate so one response cannot vote twice for the same memory.
    assessments.push({ selectedIds: [...new Set(selectedIds)], logVC, tokenLen });
  }
  if (assessments.length === 0) return new Map<string, number>();

  const consistent = srlmConsistentSet(assessments);
  if (consistent.size === 0) return new Map<string, number>();
  return srlmNormalize(srlmJointScores(assessments, consistent));
}
/**
 * Build the relevance-assessment prompt. IDs are shown inline so the parser
 * can resolve them without positional mapping. The confidence elicitation
 * follows paper Appendix B.1 (a fixed `{"confidence": ν}` format).
 */
function buildSrlmPrompt(
  context: string,
  memories: Array<{ id: string; content: string }>,
): string {
  const memList = memories
    .map((m, i) => `[${i + 1}] (id=${m.id}) ${m.content}`)
    .join("\n");
  return `You are a memory relevance assessor. Given the user's current context and a list of memories, select which memories are relevant.
CONTEXT: "${context.slice(0, 600)}"
MEMORIES:
${memList}
INSTRUCTIONS:
1. List the IDs (shown in parentheses) of relevant memories in order of relevance
2. Only include memories genuinely useful for this context
3. Report your confidence using the exact JSON format below
OUTPUT FORMAT:
RELEVANT: <comma-separated memory IDs, e.g. abc123,def456>
\`\`\`json
{"confidence": <number between 0.001 and 100.000 with up to 3 decimal points>}
\`\`\`
If no memories are relevant, output:
RELEVANT: none
\`\`\`json
{"confidence": 95.000}
\`\`\`
Be precise and nuanced in your confidence assessment.`;
}
/**
 * Stage 1 (paper §2.2.1, self-consistency filter): return the IDs selected
 * by a strict majority of `assessments`. If none reaches a majority, return
 * every ID selected at least once. Expects de-duplicated `selectedIds`.
 */
function srlmConsistentSet(assessments: CandidateAssessment[]): Set<string> {
  const votes = new Map<string, number>();
  for (const assessment of assessments) {
    for (const id of assessment.selectedIds) {
      votes.set(id, (votes.get(id) ?? 0) + 1);
    }
  }
  const threshold = assessments.length / 2;
  const majority = [...votes]
    .filter(([, count]) => count > threshold)
    .map(([id]) => id);
  return new Set<string>(majority.length > 0 ? majority : votes.keys());
}
/**
 * Stages 2–3 (paper §2.2.2): for each ID in `consistent`, compute
 * s = mean(logVC) × mean(tokenLen) over the assessments that selected it.
 * Every ID in `consistent` was selected at least once, so each mean is
 * well defined.
 */
function srlmJointScores(
  assessments: CandidateAssessment[],
  consistent: Set<string>,
): Map<string, number> {
  const totals = new Map<string, { logVC: number; tokenLen: number; n: number }>();
  for (const assessment of assessments) {
    for (const id of assessment.selectedIds) {
      if (!consistent.has(id)) continue;
      const t = totals.get(id) ?? { logVC: 0, tokenLen: 0, n: 0 };
      t.logVC += assessment.logVC;
      t.tokenLen += assessment.tokenLen;
      t.n += 1;
      totals.set(id, t);
    }
  }
  const joint = new Map<string, number>();
  for (const [id, { logVC, tokenLen, n }] of totals) {
    joint.set(id, (logVC / n) * (tokenLen / n)); // ≤ 0
  }
  return joint;
}
/**
 * Min-max normalize joint scores to [0, 1]: s_max (closest to 0) maps to 1
 * and s_min to 0. If every score is equal, every memory gets 1.0.
 */
function srlmNormalize(joint: Map<string, number>): Map<string, number> {
  const values = [...joint.values()];
  const sMin = Math.min(...values);
  const sMax = Math.max(...values);
  const range = sMax - sMin;
  const normalized = new Map<string, number>();
  for (const [id, s] of joint) {
    normalized.set(id, range > 0 ? (s - sMin) / range : 1.0);
  }
  return normalized;
}
/**
 * Parse one K-sample response into a CandidateAssessment.
 *
 * Memory IDs are resolved directly from the `RELEVANT:` line; no positional
 * index is involved. Each token is matched against `memories` in this order:
 *   1. verbatim;
 *   2. with common wrappers removed (`(id=…)`, brackets, quotes, trailing
 *      punctuation);
 *   3. case-insensitively.
 * The text is no longer lowercased before matching, which used to drop
 * every ID containing an uppercase letter. Each ID appears at most once,
 * in first-mention order. Tokens that match no memory are ignored.
 *
 * Confidence (paper §2.2.1): logVC = log(ν / 100).
 *   - ν is read from a `"confidence": <number>` pair; whitespace and a
 *     missing pair of quotes are tolerated.
 *   - ν is clamped to [0.001, 100] to avoid log(0).
 *   - ν defaults to 50 when missing or malformed; a lone "." previously
 *     produced NaN.
 * Length (paper §2.2.1): tokenLen ≈ charCount / 4, with a minimum of 1.
 *
 * @param text - Raw model response.
 * @param charCount - Response length in characters.
 * @param memories - Candidate memories that were offered to the model.
 * @returns The parsed assessment, or `null` if the response has no `RELEVANT:` line.
 */
function parseCandidateAssessment(
  text: string,
  charCount: number,
  memories: Array<{ id: string; content: string }>,
): CandidateAssessment | null {
  const relMatch = /RELEVANT:\s*(.+)/i.exec(text);
  if (!relMatch) return null;
  const relStr = relMatch[1].trim();

  const selectedIds: string[] = [];
  if (relStr.toLowerCase() !== "none") {
    const exactIds = new Set(memories.map((m) => m.id));
    const idsByLowerCase = new Map<string, string>();
    for (const { id } of memories) {
      const key = id.toLowerCase();
      // First memory wins if two IDs differ only by case.
      if (!idsByLowerCase.has(key)) idsByLowerCase.set(key, id);
    }
    const resolveId = (token: string): string | undefined => {
      const unwrapped = token
        .replace(/^["'`([{<]+/, "")
        .replace(/["'`)\]}>.;:]+$/, "")
        .replace(/^id=/i, "");
      for (const candidate of [token, unwrapped]) {
        if (exactIds.has(candidate)) return candidate;
      }
      for (const candidate of [token, unwrapped]) {
        const match = idsByLowerCase.get(candidate.toLowerCase());
        if (match !== undefined) return match;
      }
      return undefined;
    };
    for (const token of relStr.split(/[,\s]+/)) {
      const id = resolveId(token.trim());
      if (id !== undefined && !selectedIds.includes(id)) selectedIds.push(id);
    }
  }

  // Paper Appendix B.1 confidence format: {"confidence": <number>}.
  const confMatch = /"?confidence"?\s*:\s*(\d+(?:\.\d+)?|\.\d+)/i.exec(text);
  const parsedConf = confMatch ? parseFloat(confMatch[1]) : NaN;
  const rawConf = Number.isFinite(parsedConf) ? parsedConf : 50;
  const nu = Math.max(0.001, Math.min(100, rawConf));
  const logVC = Math.log(nu / 100); // always ≤ 0

  // Paper §2.2.1: Len = total tokens. Approximate: chars ÷ 4.
  const tokenLen = Math.max(1, Math.round(charCount / 4));
  return { selectedIds, logVC, tokenLen };
}
/**
 * One fact parsed by `extractFacts` from a model line of the form
 * `FACT: … | CATEGORY: … | TAGS: … | CONFIDENCE: …`.
 */
export interface ExtractedFact {
  /** The fact text, trimmed (extractFacts keeps only 5–500 characters). */
  content: string;
  /** Memory category; extractFacts falls back to "note" for unrecognized values. */
  category: MemoryCategory;
  /** Lowercased tags (extractFacts keeps at most 5, each 2–50 characters). */
  tags: string[];
  /** Model-reported confidence, clamped to [0, 1] by extractFacts. */
  confidence: number;
}
/**
 * Extract structured facts from conversation text.
 * Called by the preprocessor when AI extraction is enabled.
 *
 * @param conversationText - Conversation to mine; only the first 2000
 *   characters are sent to the model.
 * @param existingSummary - Already-known information. When non-empty, the
 *   model is asked to extract only new facts.
 * @returns Parsed facts, possibly empty. Returns [] when `callModel` yields
 *   nothing (e.g. no model loaded or the call failed) or when the response
 *   contains no parseable `FACT:` lines.
 */
export async function extractFacts(
  conversationText: string,
  existingSummary: string = "",
): Promise<ExtractedFact[]> {
  const prompt = `You are a memory extraction system. Extract key facts, preferences, and information from this conversation that would be useful to remember for future conversations.
CONVERSATION:
${conversationText.slice(0, 2000)}
${existingSummary ? `ALREADY KNOWN:\n${existingSummary}\n\nOnly extract NEW information not already known.` : ""}
For each fact, output ONE line in this exact format:
FACT: <the information> | CATEGORY: <${VALID_CATEGORIES.join("/")}> | TAGS: <comma-separated tags> | CONFIDENCE: <0.0-1.0>
Rules:
- Extract only clearly stated facts, not guesses or ambiguous statements
- CONFIDENCE: 1.0 = explicitly stated, 0.7 = strongly implied, 0.5 = inferred
- Keep each fact concise (one sentence)
- Maximum 5 facts per extraction
- If no useful facts to extract, output: NONE
OUTPUT:`;
  const raw = await callModel(
    prompt,
    AI_EXTRACT_MAX_TOKENS,
    AI_EXTRACT_TEMPERATURE,
  );
  // There is deliberately no early exit on a "NONE" line. The old
  // /^NONE$/im check discarded every fact whenever ANY line read "NONE".
  // A genuine "NONE" answer has no FACT lines and so parses to [] anyway.
  if (!raw) return [];
  const facts: ExtractedFact[] = [];
  for (const line of raw.split("\n")) {
    const fact = parseFactLine(line);
    if (fact) facts.push(fact);
  }
  return facts;
}
/**
 * Parse one `FACT: … | CATEGORY: … | TAGS: … | CONFIDENCE: …` line.
 * Returns null for lines that are not facts or whose content is outside
 * 5–500 characters.
 */
function parseFactLine(line: string): ExtractedFact | null {
  // Models often prefix lines with list markers ("- ", "* ", "1. ", "2) ").
  const trimmed = line.trim().replace(/^(?:[-*•]|\d+[.)])\s+/, "");
  if (!/^FACT:/i.test(trimmed)) return null;
  const factMatch = /FACT:\s*(.+?)\s*\|/i.exec(trimmed);
  if (!factMatch) return null;
  const content = factMatch[1].trim();
  if (content.length < 5 || content.length > 500) return null;

  const catMatch = /CATEGORY:\s*(\w+)/i.exec(trimmed);
  const rawCat = catMatch?.[1]?.toLowerCase() ?? "note";
  const category = VALID_CATEGORIES.includes(rawCat as MemoryCategory)
    ? (rawCat as MemoryCategory)
    : "note";

  // Read up to the next "|" (or end of line), so tags survive even when
  // the CONFIDENCE field is missing.
  const tagMatch = /TAGS:\s*([^|]*)/i.exec(trimmed);
  const tags = (tagMatch?.[1] ?? "")
    .split(",")
    .map((t) => t.trim().toLowerCase())
    .filter((t) => t.length >= 2 && t.length <= 50)
    .slice(0, 5);

  // Accept only well-formed numbers. "." would otherwise parse to NaN and
  // survive the clamp.
  const confMatch = /CONFIDENCE:\s*(\d+(?:\.\d+)?|\.\d+)/i.exec(trimmed);
  const parsedConf = confMatch ? parseFloat(confMatch[1]) : NaN;
  const confidence = Number.isFinite(parsedConf)
    ? Math.max(0, Math.min(1, parsedConf))
    : 0.7;

  return { content, category, tags, confidence };
}
/**
 * Ask the model whether `newContent` conflicts with existing memories.
 *
 * Only the first 8 existing memories are listed in the prompt, so only those
 * can be matched. A reference outside that range is ignored instead of
 * mapping to a memory the model never saw. Each existing memory appears at
 * most once in the result (the first line naming it wins).
 *
 * Unrecognized TYPE values default to "contradiction", and unrecognized
 * ACTION values to "keep_both".
 *
 * @param newContent - Content of the memory about to be stored.
 * @param existingMemories - Candidate existing memories, most relevant first.
 * @returns Conflicts parsed from the model output. Returns [] when there are
 *   no existing memories, `callModel` yields nothing, or no `CONFLICT:`
 *   line resolves.
 */
export async function detectConflicts(
  newContent: string,
  existingMemories: MemoryRecord[],
): Promise<MemoryConflict[]> {
  if (existingMemories.length === 0) return [];
  const shown = existingMemories.slice(0, 8);
  const existing = shown
    .map((m, i) => `[${i + 1}] (id=${m.id}) ${m.content}`)
    .join("\n");
  const prompt = `You are a memory conflict detector. Check if this NEW memory conflicts with any EXISTING memories.
NEW MEMORY: "${newContent}"
EXISTING MEMORIES:
${existing}
For each conflict found, output ONE line:
CONFLICT: <existing_index> | TYPE: <contradiction/update/duplicate> | ACTION: <keep_both/supersede/skip>
Rules:
- "contradiction": memories disagree on a fact → keep_both (let user decide)
- "update": new memory is a newer version of the same fact → supersede
- "duplicate": essentially the same information → skip
- If no conflicts: output NONE
- Max 3 conflicts
OUTPUT:`;
  const raw = await callModel(
    prompt,
    AI_CONFLICT_MAX_TOKENS,
    AI_CONFLICT_TEMPERATURE,
  );
  // There is deliberately no early exit on a "NONE" line: a multiline
  // /^NONE$/ test would drop real conflicts that appear alongside it.
  if (!raw) return [];

  const conflicts: MemoryConflict[] = [];
  const seen = new Set<string>();
  for (const line of raw.split("\n")) {
    const trimmed = line.trim().replace(/^(?:[-*•]|\d+[.)])\s+/, "");
    if (!trimmed.startsWith("CONFLICT:")) continue;
    const mem = resolveConflictTarget(trimmed, shown);
    if (!mem || seen.has(mem.id)) continue;
    seen.add(mem.id);

    const typeMatch = /TYPE:\s*(\w+)/i.exec(trimmed);
    const actionMatch = /ACTION:\s*(\w+)/i.exec(trimmed);
    const conflictType =
      (["contradiction", "update", "duplicate"] as const).find(
        (t) => typeMatch?.[1]?.toLowerCase() === t,
      ) ?? "contradiction";
    const resolution =
      (["keep_both", "supersede", "skip"] as const).find(
        (r) => actionMatch?.[1]?.toLowerCase() === r,
      ) ?? "keep_both";
    conflicts.push({
      existingId: mem.id,
      existingContent: mem.content,
      newContent,
      conflictType,
      resolution,
    });
  }
  return conflicts;
}
/**
 * Resolve the target of a `CONFLICT:` line among the memories shown to the
 * model. The line's field is read first as a 1-based index ("2", "[2]",
 * "2 (id=…)"), then as a memory id, since the listing also shows ids.
 */
function resolveConflictTarget(
  line: string,
  shown: MemoryRecord[],
): MemoryRecord | undefined {
  const field = (/CONFLICT:\s*([^|]*)/i.exec(line)?.[1] ?? "").trim();
  const idxMatch = /^\[?(\d+)\b/.exec(field);
  if (idxMatch) {
    const idx = parseInt(idxMatch[1], 10) - 1;
    if (idx >= 0 && idx < shown.length) return shown[idx];
  }
  const idToken = field
    .replace(/^[[(]?\s*(?:id=)?/i, "")
    .replace(/[\])]\s*$/, "")
    .trim();
  return shown.find((m) => m.id === idToken);
}