// src/processing/ai.ts
/**
* @file processing/ai.ts
* AI-powered memory processing using the loaded LM Studio model.
*
* Implements SRLM-style uncertainty-aware self-reflection for memory retrieval:
*
* FROM THE PAPER (arXiv:2603.15653):
* SRLM uses three complementary uncertainty signals:
* 1. Self-Consistency — sample K candidates, keep the plurality answer set
* 2. Verbalized Confidence (VC) — model self-reports confidence per step,
* aggregated in log-space: VC(p) = Σ log(ν_t / 100)
* 3. Trace Length (Len) — shorter traces = more confident reasoning
* Joint score: s(p) = VC(p) · Len(p), select p* = argmax s(p)
*
* ADAPTED FOR MEMORY RETRIEVAL:
* Instead of selecting context-interaction programs, we select which
* retrieved memories are most relevant to the current query.
*
* 1. Self-Consistency: Ask the model K times "which memories are relevant?"
* Each memory is weighted by the fraction of responses that select it.
* 2. Verbalized Confidence: Each response includes a confidence score.
* Memories endorsed with higher confidence get boosted.
* 3. Trace Length: Shorter responses (more decisive) are weighted higher.
* Joint: memories are re-ranked by the product of consistency, normalized VC,
* and inverse log trace length, then normalized so the top memory scores 1.
*
* Also provides: fact extraction, conflict detection.
* All calls are optional (graceful fallback if no model loaded).
* All calls have strict timeouts to never block the chat flow.
*/
import { LMStudioClient } from "@lmstudio/sdk";
import {
AI_EXTRACT_MAX_TOKENS,
AI_EXTRACT_TEMPERATURE,
AI_CALL_TIMEOUT_MS,
AI_CONFLICT_MAX_TOKENS,
AI_CONFLICT_TEMPERATURE,
AI_RELEVANCE_MAX_TOKENS,
AI_RELEVANCE_TEMPERATURE,
VALID_CATEGORIES,
type MemoryCategory,
} from "../constants";
import type { MemoryConflict, MemoryRecord } from "../types";
/** Lazily created LM Studio client shared by every model call in this module. */
let cachedClient: LMStudioClient | null = null;

/**
 * Return the shared LM Studio client, constructing it on first use.
 * `callModelWithMeta` resets `cachedClient` to null after a failure, so the
 * next call here builds a fresh client.
 */
function getClient(): LMStudioClient {
  return (cachedClient ??= new LMStudioClient());
}
/**
 * Send a single-turn prompt to the first model currently loaded in LM Studio.
 *
 * The WHOLE operation runs under one deadline of `timeoutMs`: listing loaded
 * models, acquiring the model handle, and streaming the response. This lets a
 * stalled or slow generation never block the caller for longer than that.
 * The deadline timer is always cleared, so no timer outlives the call.
 *
 * @param prompt - Content of the single user-role message.
 * @param maxTokens - Generation cap forwarded to `respond`.
 * @param temperature - Sampling temperature forwarded to `respond`.
 * @param timeoutMs - Overall deadline in ms (default `AI_CALL_TIMEOUT_MS`).
 * @returns The trimmed response text and its length in characters. Returns
 *   null when no model is loaded, the response is blank, the deadline expires,
 *   or any error occurs. Errors are swallowed on purpose because AI processing
 *   is optional; on any failure the cached client is dropped so the next call
 *   reconnects.
 */
async function callModelWithMeta(
  prompt: string,
  maxTokens: number,
  temperature: number,
  timeoutMs: number = AI_CALL_TIMEOUT_MS,
): Promise<{ text: string; charCount: number } | null> {
  let timer: ReturnType<typeof setTimeout> | undefined;
  // Set when the deadline fires so the (still running) stream loop below
  // stops reading at its next chunk instead of consuming the whole response.
  let expired = false;
  const deadline = new Promise<never>((_, reject) => {
    timer = setTimeout(() => {
      expired = true;
      reject(new Error("timeout"));
    }, timeoutMs);
  });

  const work = (async () => {
    const client = getClient();
    const models = await client.llm.listLoaded();
    if (!Array.isArray(models) || models.length === 0) return null;
    const model = await client.llm.model(models[0].identifier);
    const stream = model.respond([{ role: "user", content: prompt }], {
      maxTokens,
      temperature,
    });
    let text = "";
    for await (const chunk of stream) {
      // Leaving the loop invokes the iterator's return(). NOTE(review): whether
      // that also cancels generation server-side depends on the SDK — confirm.
      if (expired) break;
      text += chunk.content ?? "";
    }
    const trimmed = text.trim();
    if (!trimmed) return null;
    return { text: trimmed, charCount: trimmed.length };
  })();

  try {
    // Promise.race subscribes to both promises, so a late rejection of `work`
    // after the deadline has won is still handled (no unhandled rejection).
    return await Promise.race([work, deadline]);
  } catch {
    cachedClient = null;
    return null;
  } finally {
    clearTimeout(timer);
  }
}
/**
 * Text-only convenience over `callModelWithMeta` (kept for backward compat).
 * Forwards every argument unchanged and discards the character count.
 *
 * @returns The trimmed model response, or null when `callModelWithMeta`
 *   produced nothing.
 */
async function callModel(
  prompt: string,
  maxTokens: number,
  temperature: number,
  timeoutMs?: number,
): Promise<string | null> {
  const meta = await callModelWithMeta(prompt, maxTokens, temperature, timeoutMs);
  if (meta === null) {
    return null;
  }
  return meta.text;
}
/**
 * Structured result of one K-sample relevance assessment, as produced by
 * `parseCandidateResponse`.
 */
interface CandidateResponse {
  /** Length of the raw response in characters — the trace-length signal. */
  traceLength: number;
  /** Self-reported confidence, clamped by the parser to the range 0–100. */
  verbalizedConfidence: number;
  /**
   * Memories this sample marked relevant, in the order the model listed them.
   * Despite the name these are 0-based positions rendered as strings ("0",
   * "1", …), not memory ids.
   */
  selectedIds: string[];
}
/**
 * SRLM-inspired relevance re-ranking with three uncertainty signals.
 *
 * Generates K candidate relevance assessments in parallel. The k-th sample
 * uses temperature 0.4 + 0.15·k. Then, for every memory that at least one
 * candidate selected:
 * 1. Self-Consistency: fraction of candidates that selected it
 * 2. Verbalized Confidence: mean self-reported confidence of those candidates
 * 3. Trace Length: mean response length of those candidates (shorter = more decisive)
 *
 * Joint score per memory:
 *   score(m) = (count(m) / #candidates) × (avgVC(m) / 100) × 1 / ln(max(avgLen(m), 10))
 * Scores are then divided by the maximum so the best memory scores 1.
 *
 * Only the first 12 memories are shown to the model. Any index it returns
 * outside that window is ignored.
 *
 * @param context - Current user query / conversation context (first 600 chars are used)
 * @param memories - Candidate memories to re-rank
 * @param K - Number of parallel candidate samples (default 3 for speed)
 * @returns Map from a memory's 0-based POSITION in `memories`, as a decimal
 *   string such as "0", to its normalized score in [0, 1]. Keys are positions,
 *   not memory ids; `srlmRerank` translates them back to ids. Memories that no
 *   candidate selected are absent. The map is empty when no sample produced a
 *   parseable answer.
 */
export async function srlmRelevanceRerank(
  context: string,
  memories: Array<{ id: string; content: string }>,
  K: number = 3,
): Promise<Map<string, number>> {
  const scores = new Map<string, number>();
  // The model only ever sees this window, so its indices must be bounded by it.
  const shown = memories.slice(0, 12);
  if (shown.length === 0) return scores;
  const memList = shown.map((m, i) => `[${i + 1}] ${m.content}`).join("\n");
  const prompt = `You are a memory relevance assessor. Given the user's current context and a list of memories, select which memories are relevant.
CONTEXT: "${context.slice(0, 600)}"
MEMORIES:
${memList}
INSTRUCTIONS:
1. List the indices of relevant memories in order of relevance (most relevant first)
2. Only include memories that are genuinely useful for this context
3. Report your confidence in this assessment
OUTPUT FORMAT (exactly):
RELEVANT: <comma-separated indices, e.g. 3,1,5>
CONFIDENCE: <number 0-100>
If no memories are relevant, output:
RELEVANT: none
CONFIDENCE: 95`;

  const candidatePromises: Promise<CandidateResponse | null>[] = [];
  for (let k = 0; k < K; k++) {
    candidatePromises.push(
      callModelWithMeta(
        prompt,
        AI_RELEVANCE_MAX_TOKENS,
        0.4 + k * 0.15,
        AI_CALL_TIMEOUT_MS,
      ).then((result) =>
        result
          ? parseCandidateResponse(result.text, result.charCount, shown.length)
          : null,
      ),
    );
  }
  const settled = await Promise.allSettled(candidatePromises);
  const candidates: CandidateResponse[] = [];
  for (const s of settled) {
    if (s.status === "fulfilled" && s.value) candidates.push(s.value);
  }
  if (candidates.length === 0) return scores;

  // Per-memory totals over the candidates that selected it.
  const stats = new Map<string, { count: number; vcSum: number; lenSum: number }>();
  for (const candidate of candidates) {
    // Count each memory at most once per candidate so consistency stays <= 1
    // even if a response repeats an index (e.g. "3,3,1").
    for (const id of new Set(candidate.selectedIds)) {
      const entry = stats.get(id) ?? { count: 0, vcSum: 0, lenSum: 0 };
      entry.count += 1;
      entry.vcSum += candidate.verbalizedConfidence;
      entry.lenSum += candidate.traceLength;
      stats.set(id, entry);
    }
  }

  const rawScores = new Map<string, number>();
  let maxScore = 0;
  for (const [id, { count, vcSum, lenSum }] of stats) {
    const consistency = count / candidates.length; // 0–1
    const avgVC = vcSum / count / 100; // 0–1
    const avgLen = lenSum / count;
    // Log-damped inverse length; the floor of 10 keeps ln() well above 0.
    const inverseLenScore = 1 / Math.log(Math.max(avgLen, 10));
    const joint = consistency * avgVC * inverseLenScore;
    // A non-finite score (e.g. from a NaN confidence) would poison normalization.
    if (!Number.isFinite(joint)) continue;
    rawScores.set(id, joint);
    if (joint > maxScore) maxScore = joint;
  }

  // Normalize so the best-scoring memory gets exactly 1.
  if (maxScore > 0) {
    for (const [id, raw] of rawScores) {
      scores.set(id, raw / maxScore);
    }
  }
  return scores;
}
/**
 * Parse one candidate relevance response into structured form.
 *
 * Expected shape (see the prompt in `srlmRelevanceRerank`):
 *   RELEVANT: 3,1,5      (or "none")
 *   CONFIDENCE: 80
 *
 * Index parsing is tolerant: every run of digits on the RELEVANT line counts
 * as a 1-based index. This means "[3], [1]" works like "3,1".
 * - Indices outside 1..maxIdx are dropped.
 * - A repeated index is kept only once, at its first position.
 * - Confidence is clamped to 0–100.
 * - A missing or unparseable confidence falls back to 50.
 *
 * @param text - Raw model response.
 * @param charCount - Response length, recorded as the trace-length signal.
 * @param maxIdx - Number of memories shown to the model (exclusive upper bound
 *   of the 0-based index).
 * @returns The parsed candidate, or null when no RELEVANT: line is present.
 */
function parseCandidateResponse(
  text: string,
  charCount: number,
  maxIdx: number,
): CandidateResponse | null {
  const relMatch = /RELEVANT:\s*(.+)/i.exec(text);
  if (!relMatch) return null;
  const relStr = (relMatch[1] ?? "").trim().toLowerCase();

  const selectedIds: string[] = [];
  // "none", "none.", "none of them" all mean an empty selection.
  if (!/^none\b/.test(relStr)) {
    const seen = new Set<number>();
    for (const token of relStr.match(/\d+/g) ?? []) {
      const idx = parseInt(token, 10) - 1;
      if (idx < 0 || idx >= maxIdx || seen.has(idx)) continue;
      seen.add(idx);
      selectedIds.push(String(idx));
    }
  }

  // Require at least one digit so a bare "." cannot turn into NaN.
  const confMatch = /CONFIDENCE:\s*(\d*\.?\d+)/i.exec(text);
  const parsedConf = confMatch ? parseFloat(confMatch[1] ?? "") : NaN;
  const confidence = Number.isFinite(parsedConf)
    ? Math.max(0, Math.min(100, parsedConf))
    : 50;

  return {
    selectedIds,
    verbalizedConfidence: confidence,
    traceLength: charCount,
  };
}
/**
 * Entry point used by the retrieval engine: re-ranks the first 12 memories and
 * returns scores keyed by memory id.
 *
 * `srlmRelevanceRerank` reports scores keyed by 0-based position (as a string).
 * This wrapper converts each key back to the id of the memory at that
 * position. Keys that do not parse to an in-range position are ignored.
 *
 * @param context - Current user query / conversation context.
 * @param memories - Candidate memories; only the first 12 are considered.
 * @param K - Number of parallel samples forwarded to `srlmRelevanceRerank`.
 * @returns Map of memory id → normalized relevance score.
 */
export async function srlmRerank(
  context: string,
  memories: Array<{ id: string; content: string }>,
  K: number = 3,
): Promise<Map<string, number>> {
  const visible = memories.slice(0, 12);
  const byPosition = await srlmRelevanceRerank(context, visible, K);
  const result = new Map<string, number>();
  byPosition.forEach((score, key) => {
    const position = parseInt(key, 10);
    const inRange = position >= 0 && position < visible.length;
    if (!inRange) return;
    result.set(visible[position].id, score);
  });
  return result;
}
/** One fact pulled out of a conversation by `extractFacts`. */
export interface ExtractedFact {
  /** The fact text itself. */
  content: string;
  /** Memory category; the extractor falls back to "note" for unknown values. */
  category: MemoryCategory;
  /** Model confidence, clamped by the extractor to the range 0–1. */
  confidence: number;
  /** Lower-cased topic tags. */
  tags: string[];
}
/**
 * Extract structured facts from conversation text.
 * Called by the preprocessor when AI extraction is enabled.
 *
 * @param conversationText - Raw conversation; only the first 2000 chars are sent.
 * @param existingSummary - Already-known information. When non-empty, the
 *   model is asked to extract only NEW information.
 * @returns Parsed facts, possibly empty. Returns [] when the model produced
 *   no output or no usable FACT lines.
 */
export async function extractFacts(
  conversationText: string,
  existingSummary: string = "",
): Promise<ExtractedFact[]> {
  const prompt = `You are a memory extraction system. Extract key facts, preferences, and information from this conversation that would be useful to remember for future conversations.
CONVERSATION:
${conversationText.slice(0, 2000)}
${existingSummary ? `ALREADY KNOWN:\n${existingSummary}\n\nOnly extract NEW information not already known.` : ""}
For each fact, output ONE line in this exact format:
FACT: <the information> | CATEGORY: <${VALID_CATEGORIES.join("/")}> | TAGS: <comma-separated tags> | CONFIDENCE: <0.0-1.0>
Rules:
- Extract only clearly stated facts, not guesses or ambiguous statements
- CONFIDENCE: 1.0 = explicitly stated, 0.7 = strongly implied, 0.5 = inferred
- Keep each fact concise (one sentence)
- Maximum 5 facts per extraction
- If no useful facts to extract, output: NONE
OUTPUT:`;
  const raw = await callModel(
    prompt,
    AI_EXTRACT_MAX_TOKENS,
    AI_EXTRACT_TEMPERATURE,
  );
  if (!raw) return [];
  // A "NONE" answer simply contains no FACT lines. We deliberately do not bail
  // out on a NONE line, so a stray one next to real facts cannot discard them.
  const facts: ExtractedFact[] = [];
  for (const line of raw.split("\n")) {
    const fact = parseFactLine(line);
    if (fact) facts.push(fact);
  }
  return facts;
}

/**
 * Parse one "FACT: … | CATEGORY: … | TAGS: … | CONFIDENCE: …" line, or return
 * null if the line is not a usable fact.
 *
 * - An optional list marker ("- ", "* ", "1. ") may precede FACT:.
 * - Fact text must be 5–500 chars.
 * - An unknown category becomes "note".
 * - Tags are lower-cased, de-duplicated, limited to 2–50 chars each and 5 in total.
 * - Confidence is clamped to 0–1; when missing or unparseable it defaults to 0.7.
 */
function parseFactLine(line: string): ExtractedFact | null {
  const trimmed = line.trim();
  // The fact text ends at the first "|" or at end of line.
  const factMatch = /^(?:[-*•]|\d+[.)])?\s*FACT:\s*(.+?)\s*(?:\||$)/i.exec(trimmed);
  if (!factMatch) return null;
  const content = (factMatch[1] ?? "").trim();
  if (content.length < 5 || content.length > 500) return null;

  const rawCat = /CATEGORY:\s*(\w+)/i.exec(trimmed)?.[1]?.toLowerCase() ?? "note";
  const category: MemoryCategory = VALID_CATEGORIES.includes(rawCat as MemoryCategory)
    ? (rawCat as MemoryCategory)
    : "note";

  // TAGS may be the last field on the line, so also accept end of line.
  const rawTags = /TAGS:\s*(.+?)\s*(?:\||$)/i.exec(trimmed)?.[1] ?? "";
  const tags = [
    ...new Set(
      rawTags
        .split(",")
        .map((t) => t.trim().toLowerCase())
        .filter((t) => t.length >= 2 && t.length <= 50),
    ),
  ].slice(0, 5);

  // Require at least one digit so a bare "." cannot become NaN.
  const confMatch = /CONFIDENCE:\s*(\d*\.?\d+)/i.exec(trimmed);
  const parsedConf = confMatch ? parseFloat(confMatch[1] ?? "") : NaN;
  const confidence = Number.isFinite(parsedConf)
    ? Math.max(0, Math.min(1, parsedConf))
    : 0.7;

  return { content, category, tags, confidence };
}
/**
 * Ask the model whether `newContent` conflicts with any existing memory.
 *
 * Only the first 8 existing memories are shown to the model, and only those
 * can be reported. An index outside that window is ignored, so a hallucinated
 * index can never resolve to a memory the model never saw. Each existing
 * memory is reported at most once (the first matching line wins). Unknown TYPE
 * values fall back to "contradiction", and unknown ACTION values fall back to
 * "keep_both".
 *
 * @param newContent - The candidate memory text.
 * @param existingMemories - Memories to compare against (the first 8 are used).
 * @returns Detected conflicts. Empty when there are no memories, the model
 *   produced no output, or no CONFLICT line could be parsed.
 */
export async function detectConflicts(
  newContent: string,
  existingMemories: MemoryRecord[],
): Promise<MemoryConflict[]> {
  // The model's 1-based indices refer to this window only.
  const shown = existingMemories.slice(0, 8);
  if (shown.length === 0) return [];
  const existing = shown
    .map((m, i) => `[${i + 1}] (id=${m.id}) ${m.content}`)
    .join("\n");
  const prompt = `You are a memory conflict detector. Check if this NEW memory conflicts with any EXISTING memories.
NEW MEMORY: "${newContent}"
EXISTING MEMORIES:
${existing}
For each conflict found, output ONE line:
CONFLICT: <existing_index> | TYPE: <contradiction/update/duplicate> | ACTION: <keep_both/supersede/skip>
Rules:
- "contradiction": memories disagree on a fact → keep_both (let user decide)
- "update": new memory is a newer version of the same fact → supersede
- "duplicate": essentially the same information → skip
- If no conflicts: output NONE
- Max 3 conflicts
OUTPUT:`;
  const raw = await callModel(
    prompt,
    AI_CONFLICT_MAX_TOKENS,
    AI_CONFLICT_TEMPERATURE,
  );
  if (!raw) return [];
  // No early exit on a "NONE" line: a stray NONE next to real CONFLICT lines
  // must not discard them, and non-CONFLICT lines are skipped below anyway.
  const conflicts: MemoryConflict[] = [];
  const reportedIds = new Set<string>();
  for (const line of raw.split("\n")) {
    const trimmed = line.trim();
    // Tolerate an optional list marker ("- ", "* ", "1. ") before CONFLICT:.
    const idxMatch = /^(?:[-*•]|\d+[.)])?\s*CONFLICT:\s*(\d+)/i.exec(trimmed);
    if (!idxMatch) continue;
    const idx = parseInt(idxMatch[1] ?? "", 10) - 1;
    if (idx < 0 || idx >= shown.length) continue;
    const mem = shown[idx];
    if (mem === undefined || reportedIds.has(mem.id)) continue;

    const typeWord = /TYPE:\s*(\w+)/i.exec(trimmed)?.[1]?.toLowerCase();
    const actionWord = /ACTION:\s*(\w+)/i.exec(trimmed)?.[1]?.toLowerCase();
    const conflictType =
      (["contradiction", "update", "duplicate"] as const).find(
        (t) => typeWord === t,
      ) ?? "contradiction";
    const resolution =
      (["keep_both", "supersede", "skip"] as const).find(
        (r) => actionWord === r,
      ) ?? "keep_both";

    reportedIds.add(mem.id);
    conflicts.push({
      existingId: mem.id,
      existingContent: mem.content,
      newContent,
      conflictType,
      resolution,
    });
  }
  return conflicts;
}