// Project Files
// src/promptPreprocessor.ts
import {
text,
type ChatMessage,
type FileHandle,
type PromptPreprocessorController,
} from "@lmstudio/sdk";
import { configSchematics } from "./config";
import { MemoryService } from "./memory/MemoryService";
import { getEmbeddingModel } from "./retrieval/embeddingCache";
import { expandQueries } from "./retrieval/queryExpansion";
import { hybridScore } from "./retrieval/hybridSearch";
import { dedupeResults, type RetrievedItem } from "./retrieval/dedupe";
import { compressChunk } from "./retrieval/compression";
import { fitToBudget, type BudgetedChunk } from "./retrieval/tokenBudget";
import { rerankChunks } from "./retrieval/reranker";
import { sanitizePrompt } from "./security/promptInjection";
import { buildDocumentContext, RAG_AUGMENTED_MARKER } from "./templates/document_context";
import { normalizeWhitespace, stableHash, truncate, TEXT_EXTS } from "./utils/text";
/** One cached retrieval result plus the instant after which it is stale. */
type RetrievalCacheEntry = {
  /** Ranked items stored for a cache key. */
  value: RetrievedItem[];
  /** Epoch milliseconds; readers compare this against `Date.now()`. */
  expiresAt: number;
};

// --- Retrieval result cache -------------------------------------------------

/** Retrieval results keyed by the hash produced by `buildCacheKey`. */
const retrievalCache = new Map<string, RetrievalCacheEntry>();

/** Lifetime of a cached retrieval result (300 000 ms = 5 minutes). */
const CACHE_TTL_MS = 300_000;

/** Upper bound used by `pruneRetrievalCache` for live cache entries. */
const CACHE_MAX_SIZE = 64;

// --- File selection ---------------------------------------------------------

/** Maximum number of files handed to retrieval for one prompt. */
const MAX_RETRIEVAL_FILES = 8;

/** Raw prompts shorter than this are not saved to memory. */
const MEMORY_MIN_PROMPT_LENGTH = 20;

/** Per-handle memo of the signature computed by `fileSignature`. */
const fileSignatureCache = new WeakMap<FileHandle, string>();

/** Per-handle memo of `getFileTimestamp`; `null` records "no timestamp found". */
const fileTimestampCache = new WeakMap<FileHandle, number | null>();

/** Per-handle memo of `isTextLikeFile`. */
const fileTextLikeCache = new WeakMap<FileHandle, boolean>();

/** Memo of `selectRetrievalFiles`, keyed by the joined file signatures. */
const fileSelectionCache = new Map<string, FileHandle[]>();

/** Once the selection memo holds this many keys, the oldest one is evicted. */
const FILE_SELECTION_CACHE_MAX = 128;

// --- Failure handling -------------------------------------------------------

/** Per-query deadline passed to `withTimeout`. */
const RETRIEVAL_TIMEOUT_MS = 8_000;

/** Number of consecutive retrieval failures after which `preprocess` skips retrieval. */
const MAX_CONSECUTIVE_FAILURES = 3;

/** How long after the last failure retrieval stays skipped. */
const FAILURE_RESET_MS = 30_000;

/** Running count of back-to-back retrieval failures. */
let consecutiveFailures = 0;

/** `Date.now()` of the most recent retrieval failure (0 until one occurs). */
let lastFailureTime = 0;
/**
 * Races `promise` against a timer and settles with whichever finishes first.
 *
 * @param promise - The operation to guard.
 * @param ms - Deadline in milliseconds.
 * @returns A promise that mirrors `promise`, or rejects with a
 *   "retrieval timed out" error once `ms` elapses. The underlying operation
 *   is not cancelled; only the returned promise is rejected.
 */
function withTimeout<T>(promise: Promise<T>, ms: number): Promise<T> {
  let handle: ReturnType<typeof setTimeout> | undefined;

  const deadline = new Promise<never>((_resolve, reject) => {
    handle = setTimeout(() => {
      reject(new Error(`retrieval timed out after ${ms}ms`));
    }, ms);
  });

  // Always clear the timer so a settled race does not keep a pending timeout alive.
  const stopTimer = (): void => {
    if (handle) clearTimeout(handle);
  };

  return Promise.race([promise, deadline]).finally(stopTimer) as Promise<T>;
}
/**
 * Reports whether `error` is an object whose `name` is exactly "AbortError".
 *
 * @param error - Any caught value.
 * @returns `true` only for non-null objects carrying `name === "AbortError"`.
 */
function isAbortError(error: unknown): boolean {
  if (!error || typeof error !== "object") return false;
  if (!("name" in error)) return false;
  return (error as { name?: string }).name === "AbortError";
}
/**
 * Removes stale and surplus entries from `retrievalCache`.
 *
 * Two passes:
 *  1. every entry whose `expiresAt` is at or before now is deleted, so
 *     expired results never accumulate;
 *  2. while more than `CACHE_MAX_SIZE` entries remain, the oldest inserted
 *     entries are evicted.
 */
function pruneRetrievalCache(): void {
  const now = Date.now();

  // Deleting from a Map while iterating it is well-defined in JavaScript.
  for (const [key, entry] of retrievalCache) {
    if (entry.expiresAt <= now) retrievalCache.delete(key);
  }

  // Map iteration follows insertion order, so the first keys are the oldest.
  for (const key of retrievalCache.keys()) {
    if (retrievalCache.size <= CACHE_MAX_SIZE) break;
    retrievalCache.delete(key);
  }
}
/**
 * Looks for a timestamp on a file handle by probing a fixed list of
 * property names in priority order, and memoizes the answer per handle.
 *
 * Numeric values are accepted when finite; string values are accepted when
 * `Date.parse` understands them.
 *
 * @param file - Handle to inspect.
 * @returns The first usable timestamp, or `null` when none of the probed
 *   properties yields one.
 */
function getFileTimestamp(file: FileHandle): number | null {
  const memoized = fileTimestampCache.get(file);
  if (memoized !== undefined) return memoized;

  const fields = file as Record<string, unknown>;
  let timestamp: number | null = null;

  for (const field of [
    "lastModified",
    "updatedAt",
    "modifiedAt",
    "mtime",
    "timestamp",
    "createdAt",
    "created",
  ]) {
    const raw = fields[field];
    if (typeof raw === "number") {
      if (Number.isFinite(raw)) timestamp = raw;
    } else if (typeof raw === "string") {
      const parsedMs = Date.parse(raw);
      if (!Number.isNaN(parsedMs)) timestamp = parsedMs;
    }
    // The first property that produces a value wins.
    if (timestamp !== null) break;
  }

  // `null` is stored too, so a handle without a timestamp is only probed once.
  fileTimestampCache.set(file, timestamp);
  return timestamp;
}
/**
 * Sort comparator that puts newer files first and files without a
 * timestamp last.
 *
 * @returns A negative number when `a` should come before `b`, positive when
 *   after, and 0 when neither handle has a timestamp.
 */
function compareFileByTimestamp(a: FileHandle, b: FileHandle): number {
  const left = getFileTimestamp(a);
  const right = getFileTimestamp(b);

  if (left === null) return right === null ? 0 : 1;
  if (right === null) return -1;
  return right - left;
}
/**
 * Heuristically decides whether a file handle can be used for text
 * retrieval, memoizing the verdict per handle.
 *
 * A handle qualifies when any of the following holds, checked in order:
 *  - it has a non-empty `type` other than "image" (case-insensitive);
 *  - its name (first of `name`, `fileName`, `filename`, `originalName`,
 *    `path`) matches `TEXT_EXTS`;
 *  - it carries a truthy `content`, `text`, `document` or `snippet`.
 */
function isTextLikeFile(file: FileHandle): boolean {
  const memoized = fileTextLikeCache.get(file);
  if (memoized !== undefined) return memoized;

  const fields = file as Record<string, unknown>;
  const kind = String(fields.type ?? "").toLowerCase();
  const label = String(
    fields.name ??
      fields.fileName ??
      fields.filename ??
      fields.originalName ??
      fields.path ??
      "",
  ).toLowerCase();

  let textLike: boolean;
  if (kind !== "" && kind !== "image") {
    textLike = true;
  } else if (TEXT_EXTS.test(label)) {
    textLike = true;
  } else {
    textLike = Boolean(
      fields.content || fields.text || fields.document || fields.snippet,
    );
  }

  fileTextLikeCache.set(file, textLike);
  return textLike;
}
/**
 * Picks at most `MAX_RETRIEVAL_FILES` files to run retrieval against.
 *
 * Text-like files are preferred. When there are more of them than the cap,
 * they are ordered with `compareFileByTimestamp` and the leading ones are
 * kept. When no file looks text-like, the first files of the input are used
 * as a fallback. Results are memoized by the joined file signatures; the
 * memoized array itself is returned on a hit.
 *
 * @param files - Candidate file handles.
 * @returns The selected subset (never more than `MAX_RETRIEVAL_FILES`).
 */
function selectRetrievalFiles(files: FileHandle[]): FileHandle[] {
  const selectionKey = files.map((file) => fileSignature(file)).join(";");
  const remembered = fileSelectionCache.get(selectionKey);
  if (remembered) return remembered;

  const textLike = files.filter((file) => isTextLikeFile(file));

  let chosen: FileHandle[];
  if (textLike.length === 0) {
    chosen = files.slice(0, MAX_RETRIEVAL_FILES);
  } else if (textLike.length > MAX_RETRIEVAL_FILES) {
    // Sort a copy so the filtered list is not reordered in place.
    chosen = textLike
      .slice()
      .sort(compareFileByTimestamp)
      .slice(0, MAX_RETRIEVAL_FILES);
  } else {
    chosen = textLike;
  }

  // Evict the earliest-inserted key once the memo is full.
  if (fileSelectionCache.size >= FILE_SELECTION_CACHE_MAX) {
    const oldestKey = fileSelectionCache.keys().next().value;
    if (oldestKey !== undefined) fileSelectionCache.delete(oldestKey);
  }

  fileSelectionCache.set(selectionKey, chosen);
  return chosen;
}
/**
 * Extracts the textual payload from a retrieval item of unknown shape.
 *
 * Strings are returned unchanged. For objects, the first non-nullish value
 * among `content`, `text`, `chunk`, `document`, `snippet`, `value` is
 * stringified. Anything else yields an empty string.
 */
function toText(value: unknown): string {
  if (typeof value === "string") return value;
  if (value === null || typeof value !== "object") return "";

  const fields = value as Record<string, unknown>;
  for (const key of ["content", "text", "chunk", "document", "snippet", "value"]) {
    const candidate = fields[key];
    if (candidate !== null && candidate !== undefined) return String(candidate);
  }
  return "";
}
/**
 * Reads a numeric relevance score from a retrieval item of unknown shape.
 *
 * The first non-nullish value among `score`, `affinity`, `relevance`,
 * `similarity` is used; non-numbers are converted with `Number`.
 *
 * @returns The finite score, or 0 when the input is not an object, has no
 *   score field, or the value is not a finite number.
 */
function toScore(value: unknown): number {
  if (value === null || typeof value !== "object") return 0;

  const fields = value as Record<string, unknown>;
  let raw: unknown = 0;
  for (const key of ["score", "affinity", "relevance", "similarity"]) {
    const candidate = fields[key];
    if (candidate !== null && candidate !== undefined) {
      raw = candidate;
      break;
    }
  }

  const numeric = typeof raw === "number" ? raw : Number(raw);
  return Number.isFinite(numeric) ? numeric : 0;
}
/**
 * Derives a human-readable source label from a retrieval item of unknown
 * shape.
 *
 * Lookup order:
 *  1. the first non-nullish value among `fileName`, `filename`,
 *     `originalName`, `sourceName`, `path` on the item itself (stringified);
 *  2. otherwise the item's `source`: used directly when it is a string, or,
 *     when it is an object, the first string among its `name`, `fileName`,
 *     `filename`, `originalName`, `path`.
 *
 * Step 2 exists because retrieval entries may reference their originating
 * file through a nested handle instead of a flat name field.
 * NOTE(review): the LM Studio SDK's retrieval entries are believed to expose
 * `source` as a file handle with a `name` — confirm against the installed SDK.
 *
 * @returns The label, or an empty string when nothing usable is found.
 */
function toCitation(value: unknown): string {
  if (!value || typeof value !== "object") return "";
  const obj = value as Record<string, unknown>;

  const direct =
    obj.fileName ??
    obj.filename ??
    obj.originalName ??
    obj.sourceName ??
    obj.path;
  if (direct !== null && direct !== undefined) return String(direct);

  const source = obj.source;
  if (typeof source === "string") return source;
  if (source && typeof source === "object") {
    const handle = source as Record<string, unknown>;
    const nested =
      handle.name ??
      handle.fileName ??
      handle.filename ??
      handle.originalName ??
      handle.path;
    // Only accept real strings here to avoid "[object Object]" labels.
    if (typeof nested === "string") return nested;
  }
  return "";
}
/**
 * Converts a raw retrieval item of unknown shape into a `RetrievedItem`.
 *
 * The text is trimmed and passed through `truncate` with a limit of 5000;
 * the source label fills both `citation` and `sourceName` (or leaves both
 * `undefined` when empty); `confidence` is the score clamped to [0, 1].
 */
function normalizeResult(value: unknown): RetrievedItem {
  const body = truncate(toText(value).trim(), 5000);
  const label = toCitation(value) || undefined;
  const rawScore = toScore(value);
  const clamped = Math.max(0, Math.min(1, rawScore));

  return {
    text: body,
    score: rawScore,
    citation: label,
    sourceName: label,
    confidence: clamped,
  };
}
/**
 * Builds a string that identifies a file handle for use in cache keys,
 * memoized per handle.
 *
 * The signature is `type|name|id`, where
 *  - `name` is the first non-nullish of `name`, `fileName`, `filename`,
 *    `originalName`, `path`;
 *  - `id` is the first non-nullish of `id`, `uuid`, `hash`, `identifier`,
 *    `source`.
 * Empty fields keep their position, so a value can never be mistaken for a
 * different field (e.g. an id being read as a name). When all three fields
 * are empty, a hash over the handle's sorted own enumerable properties is
 * used instead.
 *
 * NOTE(review): `identifier` is included on the assumption that SDK file
 * handles expose their unique key under that name — confirm.
 *
 * @param file - Handle to fingerprint.
 * @returns The signature string.
 */
function fileSignature(file: FileHandle): string {
  const cached = fileSignatureCache.get(file);
  if (cached !== undefined) return cached;

  const record = file as Record<string, unknown>;
  const type = String(record.type ?? "");
  const name = String(
    record.name ??
      record.fileName ??
      record.filename ??
      record.originalName ??
      record.path ??
      "",
  );
  const id = String(
    record.id ??
      record.uuid ??
      record.hash ??
      record.identifier ??
      record.source ??
      "",
  );

  const hasIdentity = type !== "" || name !== "" || id !== "";
  const signature = hasIdentity
    ? [type, name, id].join("|")
    : stableHash(
        Object.keys(record)
          .sort()
          .map((key) => `${key}:${String(record[key])}`)
          .join("|"),
      );

  fileSignatureCache.set(file, signature);
  return signature;
}
/**
 * Produces the retrieval-cache key for one request.
 *
 * The prompt, the ";"-joined file signatures and the three numeric settings
 * are concatenated with NUL separators and passed through `stableHash`.
 *
 * @param prompt - Prompt text used for retrieval.
 * @param files - Files the retrieval runs against.
 * @param limit - Per-query result limit.
 * @param multiQueryCount - Number of expanded queries.
 * @param threshold - Minimum accepted score.
 * @returns The hashed key.
 */
function buildCacheKey(
  prompt: string,
  files: FileHandle[],
  limit: number,
  multiQueryCount: number,
  threshold: number,
): string {
  const signatures = files.map((file) => fileSignature(file)).join(";");
  const material = `${prompt}\u0000${signatures}\u0000${limit}\u0000${multiQueryCount}\u0000${threshold}`;
  return stableHash(material);
}
/**
 * Normalizes whatever `files.retrieve` resolved to into a plain array of
 * entries: arrays pass through, objects contribute their `entries` array,
 * anything else becomes an empty list.
 *
 * NOTE(review): the LM Studio SDK is believed to resolve to an object of the
 * form `{ entries: [...] }` — confirm against the installed SDK version.
 */
function extractRetrievalEntries(result: unknown): unknown[] {
  if (Array.isArray(result)) return result;
  if (result && typeof result === "object") {
    const entries = (result as { entries?: unknown }).entries;
    if (Array.isArray(entries)) return entries;
  }
  return [];
}

/**
 * Parses the configured affinity threshold. Unlike `Number(x) || fallback`,
 * an explicit 0 is honoured; only missing or non-numeric values fall back.
 */
function readAffinityThreshold(raw: unknown, fallback: number): number {
  if (raw === null || raw === undefined || raw === "") return fallback;
  const parsed = typeof raw === "number" ? raw : Number(raw);
  return Number.isFinite(parsed) ? parsed : fallback;
}

/**
 * Runs retrieval for `prompt` over `files` using several expanded queries
 * and returns the merged, de-duplicated and (for more than four items)
 * re-ranked results.
 *
 * Behaviour:
 *  - Settings are read from the plugin config and clamped: limit to
 *    [1, 24], query count to [1, 6], threshold to [0, 1].
 *  - A fresh cache entry or an identical in-flight request is reused.
 *  - Each query is guarded by `RETRIEVAL_TIMEOUT_MS`. A failing query
 *    contributes no items; results are then returned but not cached, so a
 *    transient failure is not remembered for the whole TTL.
 *  - Each item is re-scored with `hybridScore` against the query that
 *    produced it and kept only when the score reaches the threshold.
 *
 * NOTE(review): callers that share an in-flight request also share the
 * abort signal of the call that started it.
 *
 * @param ctl - Preprocessor controller (config, client, abort signal).
 * @param prompt - Prompt text to retrieve for.
 * @param files - Files to search.
 * @returns The ranked retrieved items (possibly empty).
 * @throws Re-throws abort errors; rejects when the embedding model cannot
 *   be obtained or when every expanded query failed.
 */
async function retrieveAcrossQueries(
  ctl: PromptPreprocessorController,
  prompt: string,
  files: FileHandle[],
): Promise<RetrievedItem[]> {
  const config = ctl.getPluginConfig(configSchematics);
  const retrievalLimit = Math.max(
    1,
    Math.min(24, Number(config.get("retrievalLimit")) || 6),
  );
  const multiQueryCount = Math.max(
    1,
    Math.min(6, Number(config.get("multiQueryCount")) || 2),
  );
  const threshold = Math.max(
    0,
    Math.min(
      1,
      readAffinityThreshold(config.get("retrievalAffinityThreshold"), 0.55),
    ),
  );

  const cacheKey = buildCacheKey(
    prompt,
    files,
    retrievalLimit,
    multiQueryCount,
    threshold,
  );
  const cached = retrievalCache.get(cacheKey);
  if (cached && cached.expiresAt > Date.now()) {
    return cached.value;
  }
  const inflight = inFlightRetrievals.get(cacheKey);
  if (inflight !== undefined) return inflight;

  const task = (async (): Promise<RetrievedItem[]> => {
    const embeddingModel = await getEmbeddingModel(ctl);
    const queries = expandQueries(prompt, multiQueryCount);
    const failures: unknown[] = [];

    const queryResults = await Promise.all(
      queries.map(async (query) => {
        try {
          const raw: unknown = await withTimeout(
            ctl.client.files.retrieve(query, files, {
              embeddingModel,
              limit: retrievalLimit,
              signal: ctl.abortSignal,
            }),
            RETRIEVAL_TIMEOUT_MS,
          );
          return { query, items: extractRetrievalEntries(raw) };
        } catch (error) {
          if (isAbortError(error)) throw error;
          // Best effort: one failing query must not discard the others.
          failures.push(error);
          return { query, items: [] as unknown[] };
        }
      }),
    );

    // If nothing succeeded, surface the failure instead of reporting
    // "no results", so the caller can account for it.
    if (queries.length > 0 && failures.length === queries.length) {
      const first = failures[0];
      throw first instanceof Error ? first : new Error(String(first));
    }

    const results: RetrievedItem[] = [];
    for (const { query, items } of queryResults) {
      for (const item of items) {
        const normalized = normalizeResult(item);
        if (!normalized.text) continue;
        const scored = {
          ...normalized,
          score: hybridScore(query, normalized.text, normalized.score),
        };
        if (scored.score >= threshold) {
          results.push(scored);
        }
      }
    }

    const deduped = dedupeResults(results);
    const ranked = deduped.length <= 4 ? deduped : rerankChunks(deduped, prompt);

    // Only complete results are cached; prune after inserting so the size
    // bound also covers the new entry.
    if (failures.length === 0) {
      retrievalCache.set(cacheKey, {
        value: ranked,
        expiresAt: Date.now() + CACHE_TTL_MS,
      });
      pruneRetrievalCache();
    }
    return ranked;
  })();

  inFlightRetrievals.set(cacheKey, task);
  const release = (): void => {
    if (inFlightRetrievals.get(cacheKey) === task) {
      inFlightRetrievals.delete(cacheKey);
    }
  };
  // `then(release, release)` handles both outcomes, so this bookkeeping
  // chain never produces an unhandled rejection when `task` fails.
  void task.then(release, release);
  return task;
}
// Retrievals that are currently running, keyed by the same cache key as
// `retrievalCache`, so identical concurrent requests share one promise.
// Declared after `retrieveAcrossQueries`, which only reads it when called.
const inFlightRetrievals = new Map<string, Promise<RetrievedItem[]>>();
/**
 * Computes the context budget passed to `fitToBudget`.
 *
 * Starts from 10 000 (12 000 when more than six sources are selected), adds
 * 40% of the whitespace-normalized prompt length up to 8 000, and clamps the
 * sum to the range [6 000, 18 000].
 *
 * @param prompt - Prompt text whose length drives the growth term.
 * @param selectedSources - Number of sources selected for the context.
 * @returns The budget value.
 */
function adaptiveBudget(prompt: string, selectedSources: number): number {
  const promptLength = normalizeWhitespace(prompt).length;
  const baseline = selectedSources > 6 ? 12_000 : 10_000;
  const bonus = Math.min(8_000, Math.floor(promptLength * 0.4));

  const uncapped = baseline + bonus;
  if (uncapped > 18_000) return 18_000;
  if (uncapped < 6_000) return 6_000;
  return uncapped;
}
/**
 * Prompt preprocessor entry point: augments the user's message with context
 * retrieved from the non-image files attached to the conversation.
 *
 * The original message is returned unchanged when
 *  - the prompt is blank or already contains `RAG_AUGMENTED_MARKER`;
 *  - retrieval failed `MAX_CONSECUTIVE_FAILURES` times in a row and the last
 *    failure is less than `FAILURE_RESET_MS` old;
 *  - there are no non-image files, retrieval fails or finds nothing, or no
 *    chunk fits the budget.
 *
 * Otherwise the top chunks (optionally compressed) are fitted to an adaptive
 * budget and combined with a citation list and optional memory via
 * `buildDocumentContext`.
 *
 * @param ctl - Preprocessor controller supplied by the host.
 * @param userMessage - The incoming user message.
 * @returns The original message, or the augmented prompt.
 * @throws Re-throws abort errors raised during retrieval.
 */
export async function preprocess(
  ctl: PromptPreprocessorController,
  userMessage: ChatMessage,
) {
  const rawPrompt = userMessage.getText();
  if (!rawPrompt.trim()) return userMessage;
  // Never augment a message twice.
  if (rawPrompt.includes(RAG_AUGMENTED_MARKER)) return userMessage;

  // Skip retrieval for a while after repeated failures, then try again.
  if (consecutiveFailures >= MAX_CONSECUTIVE_FAILURES) {
    if (Date.now() - lastFailureTime < FAILURE_RESET_MS) return userMessage;
    consecutiveFailures = 0;
  }

  const config = ctl.getPluginConfig(configSchematics);
  const prompt = config.get("enablePromptInjectionProtection")
    ? sanitizePrompt(rawPrompt)
    : rawPrompt;

  const history = await ctl.pullHistory();
  history.append(userMessage);
  const files = history
    .getAllFiles(ctl.client)
    .filter((file: FileHandle) => file.type !== "image");
  if (!files.length) return userMessage;

  const retrievalFiles = selectRetrievalFiles(files);
  let retrieved: RetrievedItem[] = [];
  try {
    retrieved = await retrieveAcrossQueries(ctl, prompt, retrievalFiles);
    consecutiveFailures = 0;
  } catch (error) {
    if (isAbortError(error)) throw error;
    // Retrieval is best effort: record the failure and fall back to the
    // unmodified message.
    consecutiveFailures++;
    lastFailureTime = Date.now();
    return userMessage;
  }
  if (!retrieved.length) return userMessage;

  const selectedSourceCount = Math.min(
    retrieved.length,
    Math.max(1, Math.min(5, Number(config.get("selectedSourceCount")) || 5)),
  );
  // Compression only applies when at least three items were retrieved.
  const shouldCompress =
    Boolean(config.get("enableCompression")) && retrieved.length >= 3;
  const enriched: BudgetedChunk[] = retrieved
    .slice(0, selectedSourceCount)
    .map((item) => ({
      ...item,
      text: shouldCompress ? compressChunk(item.text) : item.text,
    }));

  const selected = fitToBudget(
    enriched,
    adaptiveBudget(prompt, selectedSourceCount),
  );
  if (!selected.length) return userMessage;

  // Each distinct source is listed once, numbered by the 1-based position
  // of its first chunk in `selected`.
  const activeSources = new Map<string, number>();
  selected.forEach((item, index) => {
    const key = item.sourceName || item.citation || `chunk-${index + 1}`;
    if (!activeSources.has(key)) activeSources.set(key, index + 1);
  });
  const citations = Array.from(
    activeSources,
    ([name, num]) => `[${num}] ${name}`,
  ).join("\n");

  let memory = "";
  if (config.get("enableMemory")) {
    // NOTE(review): memory is keyed on and stores the raw prompt, not the
    // sanitized one — confirm this is intended when protection is enabled.
    const topic = MemoryService.extractTopic(rawPrompt);
    memory = MemoryService.retrieve(topic);
    if (topic && rawPrompt.length >= MEMORY_MIN_PROMPT_LENGTH) {
      MemoryService.save(topic, truncate(rawPrompt, 240));
    }
  }

  const context = selected.map((item) => item.text).join("\n\n---\n\n");

  // NOTE(review): `text` from the SDK looks like a template-tag helper;
  // confirm that calling it with a plain string is supported.
  return text(
    buildDocumentContext({
      citations,
      context: context || "(no retrieved context)",
      memory,
      prompt,
    }),
  );
}