// Project file: src/search/semanticSearch.ts
/**
* Semantic Search for Draw Things Index
*
* Uses pre-computed embeddings for cross-language and conceptual search.
* Finds semantically similar prompts even without keyword overlap.
*/
import { cosineSimilarity, findTopK, similarityToScore } from '../embeddings';
import type { IndexedGeneration, GenerationMatch, MatchType } from '../types';
export interface SemanticSearchOptions {
/** Maximum number of semantic matches to return */
maxResults?: number;
/** Minimum cosine similarity threshold (0.0 - 1.0) */
minSimilarity?: number;
/** Prompt keys to exclude (already found by keyword search) */
excludePrompts?: Set<string>;
}
const DEFAULT_OPTIONS: Required<Omit<SemanticSearchOptions, 'excludePrompts'>> = {
maxResults: 5,
minSimilarity: 0.6,
};
export interface SemanticMatch {
generation: IndexedGeneration;
similarity: number;
score: number; // 0-100 normalized
}
/**
* Search for semantically similar generations using embeddings
*
* @param queryEmbedding Pre-computed embedding for the search query
* @param generations Indexed generations with embeddings
* @param options Search options
* @returns Semantic matches sorted by similarity
*/
export function semanticSearch(
queryEmbedding: number[],
generations: IndexedGeneration[],
options: SemanticSearchOptions = {}
): SemanticMatch[] {
const opts = { ...DEFAULT_OPTIONS, ...options };
// Filter to generations with embeddings
const withEmbeddings = generations.filter(g =>
g.promptEmbedding &&
g.promptEmbedding.length > 0 &&
// Exclude prompts already found by keyword search
!(opts.excludePrompts?.has(g.prompt))
);
if (withEmbeddings.length === 0) {
return [];
}
// Find top-k most similar
const results = findTopK(
queryEmbedding,
withEmbeddings,
g => g.promptEmbedding,
opts.maxResults * 2, // Fetch extra for filtering
opts.minSimilarity
);
// Convert to SemanticMatch and limit
return results
.slice(0, opts.maxResults)
.map(({ item, similarity }) => ({
generation: item,
similarity,
score: similarityToScore(similarity),
}));
}
/**
* Convert a SemanticMatch to GenerationMatch format
* (for integration with existing search result structure)
*/
export function semanticMatchToGenerationMatch(
match: SemanticMatch
): GenerationMatch {
const gen = match.generation;
// Normalize LoRAs to string array
const loras = gen.loras?.map(l => typeof l === 'string' ? l : l.model);
return {
prompt: gen.prompt,
negativePrompt: gen.negativePrompt,
model: gen.model,
loras,
seed: gen.seed,
steps: gen.steps,
cfgScale: gen.cfgScale,
width: gen.width,
height: gen.height,
imagePaths: gen.imagePaths || [],
httpPreviewUrls: gen.httpPreviewUrls,
sourceInfo: gen.sourceInfo,
timestamp: gen.timestamp,
matchScore: match.score,
matchType: 'semantic' as MatchType,
// Note: matchedTerms not applicable for semantic search
};
}
/**
* Check if generations have embeddings available
*/
export function hasEmbeddings(generations: IndexedGeneration[]): boolean {
return generations.some(g => g.promptEmbedding && g.promptEmbedding.length > 0);
}
/**
* Get embedding statistics for debugging
*/
export function getEmbeddingStats(generations: IndexedGeneration[]): {
total: number;
withEmbeddings: number;
embeddingModel: string | null;
dimension: number | null;
} {
const withEmbeddings = generations.filter(g => g.promptEmbedding?.length);
const firstEmbedding = withEmbeddings[0]?.promptEmbedding;
return {
total: generations.length,
withEmbeddings: withEmbeddings.length,
embeddingModel: withEmbeddings[0]?.embeddingModel || null,
dimension: firstEmbedding?.length || null,
};
}