// src/generator.ts
import { type Chat, type GeneratorController, type InferParsedConfig } from "@lmstudio/sdk";
import OpenAI from "openai";
import {
  type ChatCompletionChunk,
  type ChatCompletionMessageParam,
  type ChatCompletionMessageToolCall,
  type ChatCompletionTool,
  type ChatCompletionToolMessageParam,
} from "openai/resources/index";
import { configSchematics, globalConfigSchematics } from "./config";
/* -------------------------------------------------------------------------- */
/* Types */
/* -------------------------------------------------------------------------- */
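/** Accumulates fragments of a single tool call as they stream in. */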
type ToolCallState = {
id: string;
name: string | null;
index: number;
arguments: string;
};
/* -------------------------------------------------------------------------- */
/* Build helpers */
/* -------------------------------------------------------------------------- */
/** Build an OpenAI-compatible client pointed at the provider inferred from the model name. */
function createOpenAI(
globalConfig: InferParsedConfig<typeof globalConfigSchematics>,
model: string,
) {
const overrideBaseUrl = globalConfig.get("overrideBaseUrl");
const thirdPartyApiKey = globalConfig.get("thirdPartyApiKey");
// Detect model provider based on model name
const isOpenAIModel = model.startsWith("gpt-") || model.startsWith("o1-");
  const isMistralModel =
    model.includes("mistral") ||
    model.startsWith("codestral-") ||
    model.startsWith("open-codestral-");
const isAnthropicModel = model.startsWith("claude-");
// Determine base URL
let baseURL: string;
if (overrideBaseUrl) {
baseURL = overrideBaseUrl;
} else if (isMistralModel) {
baseURL = "https://api.mistral.ai/v1";
} else if (isOpenAIModel) {
baseURL = "https://api.openai.com/v1";
} else if (isAnthropicModel) {
baseURL = "https://api.anthropic.com/v1";
} else {
// Default to OpenAI for unknown models
baseURL = "https://api.openai.com/v1";
}
// Determine API key with proper priority
let apiKey: string;
if (overrideBaseUrl && thirdPartyApiKey) {
// Override URL with third-party key takes highest priority
apiKey = thirdPartyApiKey;
} else if (isMistralModel) {
// Mistral models use the third-party API key
apiKey = thirdPartyApiKey;
} else if (isOpenAIModel) {
apiKey = globalConfig.get("openaiApiKey");
} else if (isAnthropicModel) {
apiKey = globalConfig.get("anthropicApiKey");
} else {
// Fallback to OpenAI key for unknown models
apiKey = globalConfig.get("openaiApiKey");
}
  const provider = isMistralModel ? "Mistral" : isOpenAIModel ? "OpenAI" : isAnthropicModel ? "Anthropic" : "Unknown";
  console.info(`[OpenAI Client] Model: ${model}`);
  console.info(`[OpenAI Client] Detected provider: ${provider}`);
  console.info(`[OpenAI Client] Using base URL: ${baseURL}`);
  console.info(`[OpenAI Client] API key configured: ${apiKey ? "Yes" : "No"}`);
return new OpenAI({
apiKey,
baseURL,
    defaultHeaders: {
      "Content-Type": "application/json",
    },
});
}
/** Substitute dynamic placeholders such as {date}, {time}, and {model} into the system prompt (matched case-insensitively). */
function processSystemPrompt(
systemPrompt: string,
model: string,
): string {
if (!systemPrompt) {
return "";
}
const now = new Date();
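  // Values are computed once per call so every placeholder sees the same timestamp.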
const variables: Record<string, string> = {
date: now.toLocaleDateString(),
time: now.toLocaleTimeString(),
datetime: now.toLocaleString(),
model: model,
timestamp: now.toISOString(),
year: now.getFullYear().toString(),
month: (now.getMonth() + 1).toString().padStart(2, "0"),
day: now.getDate().toString().padStart(2, "0"),
};
let processed = systemPrompt;
for (const [key, value] of Object.entries(variables)) {
const regex = new RegExp(`\\{${key}\\}`, "gi");
processed = processed.replace(regex, value);
}
return processed;
}
/** Validate system prompt for potential issues. */
function validateSystemPrompt(systemPrompt: string): { valid: boolean; warnings: string[] } {
const warnings: string[] = [];
if (systemPrompt.length > 10000) {
warnings.push("System prompt is very long (>10000 chars). This may impact performance.");
}
// Check for unresolved variables
const unresolvedVars = systemPrompt.match(/\{[^}]+\}/g);
if (unresolvedVars) {
const knownVars = ["date", "time", "datetime", "model", "timestamp", "year", "month", "day"];
const unknown = unresolvedVars
.map(v => v.slice(1, -1))
.filter(v => !knownVars.includes(v.toLowerCase()));
if (unknown.length > 0) {
warnings.push(`Unknown variables detected: ${unknown.join(", ")}`);
}
}
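  // Only soft warnings are produced; there is currently no hard-failure case.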
return { valid: true, warnings };
}
/** Convert internal chat history to the format expected by OpenAI. */
function toOpenAIMessages(
history: Chat,
systemPrompt?: string,
): ChatCompletionMessageParam[] {
const messages: ChatCompletionMessageParam[] = [];
// Add custom system prompt if provided
if (systemPrompt) {
messages.push({ role: "system", content: systemPrompt });
}
for (const message of history) {
switch (message.getRole()) {
case "system":
// Skip system messages if we already added a custom system prompt
if (!systemPrompt) {
messages.push({ role: "system", content: message.getText() });
}
break;
case "user":
messages.push({ role: "user", content: message.getText() });
break;
case "assistant": {
const toolCalls: ChatCompletionMessageToolCall[] = message
.getToolCallRequests()
.map(toolCall => ({
id: toolCall.id ?? "",
type: "function",
function: {
name: toolCall.name,
arguments: JSON.stringify(toolCall.arguments ?? {}),
},
}));
messages.push({
role: "assistant",
content: message.getText(),
...(toolCalls.length ? { tool_calls: toolCalls } : {}),
});
break;
}
case "tool": {
message.getToolCallResults().forEach(toolCallResult => {
messages.push({
role: "tool",
tool_call_id: toolCallResult.toolCallId ?? "",
content: toolCallResult.content,
} as ChatCompletionToolMessageParam);
});
break;
}
}
}
return messages;
}
/** Convert LM Studio tool definitions to OpenAI function-tool descriptors. */
function toOpenAITools(ctl: GeneratorController): ChatCompletionTool[] | undefined {
const tools = ctl.getToolDefinitions().map<ChatCompletionTool>(t => ({
type: "function",
function: {
name: t.function.name,
description: t.function.description,
parameters: t.function.parameters ?? {},
},
}));
return tools.length ? tools : undefined;
}
/* -------------------------------------------------------------------------- */
/* Stream-handling utils */
/* -------------------------------------------------------------------------- */
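/** Cancel the in-flight HTTP request when the user aborts generation. */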
function wireAbort(ctl: GeneratorController, stream: { controller: AbortController }) {
ctl.onAborted(() => {
console.info("Generation aborted by user.");
stream.controller.abort();
});
}
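/** Pump streamed chunks into the controller, reassembling tool calls from their fragments. */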
async function consumeStream(stream: AsyncIterable<ChatCompletionChunk>, ctl: GeneratorController) {
let current: ToolCallState | null = null;
  function maybeFlushCurrentToolCall() {
    if (current === null || current.name === null) {
      return;
    }
    // JSON.parse throws on empty or truncated argument strings, so guard and
    // fall back to an empty object rather than crashing the stream.
    let parsedArguments: Record<string, unknown> = {};
    if (current.arguments) {
      try {
        parsedArguments = JSON.parse(current.arguments);
      } catch (error) {
        console.warn("Failed to parse tool call arguments:", current.arguments, error);
      }
    }
    ctl.toolCallGenerationEnded({
      type: "function",
      name: current.name,
      arguments: parsedArguments,
      id: current.id,
    });
    current = null;
  }
  for await (const chunk of stream) {
    console.info("Received chunk:", JSON.stringify(chunk));
    const delta = chunk.choices?.[0]?.delta;
    if (!delta) continue;
/* Text streaming */
if (delta.content) {
ctl.fragmentGenerated(delta.content);
}
/* Tool-call streaming */
for (const toolCall of delta.tool_calls ?? []) {
if (toolCall.id !== undefined) {
maybeFlushCurrentToolCall();
current = { id: toolCall.id, name: null, index: toolCall.index, arguments: "" };
ctl.toolCallGenerationStarted();
}
if (toolCall.function?.name && current) {
current.name = toolCall.function.name;
ctl.toolCallGenerationNameReceived(toolCall.function.name);
}
if (toolCall.function?.arguments && current) {
current.arguments += toolCall.function.arguments;
ctl.toolCallGenerationArgumentFragmentGenerated(toolCall.function.arguments);
}
}
/* Finalize tool call */
if (chunk.choices?.[0]?.finish_reason === "tool_calls" && current?.name) {
maybeFlushCurrentToolCall();
}
  }
  // Flush any tool call still pending in case the stream ended without an
  // explicit "tool_calls" finish reason (some providers omit it).
  maybeFlushCurrentToolCall();
  console.info("Generation completed.");
}
/* -------------------------------------------------------------------------- */
/* API */
/* -------------------------------------------------------------------------- */
export async function generate(ctl: GeneratorController, history: Chat) {
const config = ctl.getPluginConfig(configSchematics);
const model = config.get("model");
const temperature = config.get("temperature");
const top_p = config.get("top_p");
const frequency_penalty = config.get("frequency_penalty");
const presence_penalty = config.get("presence_penalty");
const max_tokens = config.get("max_tokens");
const systemPromptTemplate = config.get("systemPrompt");
// Get global config
const globalConfig = ctl.getGlobalPluginConfig(globalConfigSchematics);
/* 1. Process and validate system prompt */
let processedSystemPrompt: string | undefined;
if (systemPromptTemplate) {
processedSystemPrompt = processSystemPrompt(systemPromptTemplate, model);
const validation = validateSystemPrompt(processedSystemPrompt);
if (validation.warnings.length > 0) {
console.warn("System prompt validation warnings:", validation.warnings);
}
console.info("Using custom system prompt:", processedSystemPrompt);
}
/* 2. Setup client & payload */
const openai = createOpenAI(globalConfig, model);
const messages = toOpenAIMessages(history, processedSystemPrompt);
const tools = toOpenAITools(ctl);
/* 3. Kick off streaming completion */
const stream = await openai.chat.completions.create({
model: model,
messages,
tools,
stream: true,
temperature,
top_p,
frequency_penalty,
presence_penalty,
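    // Treat max_tokens <= 0 as "unset" so the provider's default applies.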
...(max_tokens > 0 ? { max_tokens } : {}),
});
/* 4. Abort wiring & stream processing */
  wireAbort(ctl, stream);
  await consumeStream(stream, ctl);
}