openai-compat-endpoint

src/generator.ts

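/**
 * Streaming generator for the openai-compat-endpoint plugin: converts the LM
 * Studio chat history into OpenAI-style requests, detects the upstream
 * provider from the model name, and relays streamed text and tool calls back
 * through the GeneratorController.
 */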
import { type Chat, type GeneratorController, type InferParsedConfig } from "@lmstudio/sdk";
import OpenAI from "openai";
import {
    type ChatCompletionMessageParam,
    type ChatCompletionMessageToolCall,
    type ChatCompletionTool,
    type ChatCompletionToolMessageParam,
} from "openai/resources/index";
import { configSchematics, globalConfigSchematics } from "./config";

/* -------------------------------------------------------------------------- */
/*                                   Types                                    */
/* -------------------------------------------------------------------------- */

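/** A tool call under construction, accumulated from streamed delta fragments. */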
type ToolCallState = {
    id: string;
    name: string | null;
    index: number;
    arguments: string;
};

/* -------------------------------------------------------------------------- */
/*                               Build helpers                                */
/* -------------------------------------------------------------------------- */

/** Build a pre-configured OpenAI client. */
function createOpenAI(
    globalConfig: InferParsedConfig<typeof globalConfigSchematics>,
    model: string,
) {
    const overrideBaseUrl = globalConfig.get("overrideBaseUrl");
    const thirdPartyApiKey = globalConfig.get("thirdPartyApiKey");

    // Detect model provider based on model name
    const isOpenAIModel = model.startsWith("gpt-") || model.startsWith("o1-");
    // `includes("mistral")` already covers the "mistral-" and "open-mistral-" prefixes.
    const isMistralModel = model.includes("mistral") ||
        model.startsWith("codestral-") || model.startsWith("open-codestral-");
    const isAnthropicModel = model.startsWith("claude-");

    // Determine base URL
    let baseURL: string;
    if (overrideBaseUrl) {
        baseURL = overrideBaseUrl;
    } else if (isMistralModel) {
        baseURL = "https://api.mistral.ai/v1";
    } else if (isOpenAIModel) {
        baseURL = "https://api.openai.com/v1";
    } else if (isAnthropicModel) {
        baseURL = "https://api.anthropic.com/v1";
    } else {
        // Default to OpenAI for unknown models
        baseURL = "https://api.openai.com/v1";
    }

    // Determine API key with proper priority
    let apiKey: string;
    if (overrideBaseUrl && thirdPartyApiKey) {
        // Override URL with third-party key takes highest priority
        apiKey = thirdPartyApiKey;
    } else if (isMistralModel) {
        // Mistral models use the third-party API key
        apiKey = thirdPartyApiKey;
    } else if (isOpenAIModel) {
        apiKey = globalConfig.get("openaiApiKey");
    } else if (isAnthropicModel) {
        apiKey = globalConfig.get("anthropicApiKey");
    } else {
        // Fallback to OpenAI key for unknown models
        apiKey = globalConfig.get("openaiApiKey");
    }

    const provider =
        isMistralModel ? "Mistral"
        : isOpenAIModel ? "OpenAI"
        : isAnthropicModel ? "Anthropic"
        : "Unknown";
    console.info(`[OpenAI Client] Model: ${model}`);
    console.info(`[OpenAI Client] Detected provider: ${provider}`);
    console.info(`[OpenAI Client] Using base URL: ${baseURL}`);
    console.info(`[OpenAI Client] API key configured: ${apiKey ? "Yes" : "No"}`);

    return new OpenAI({
        apiKey,
        baseURL,
        defaultHeaders: {
            "Content-Type": "application/json",
        },
    });
}

/** Process system prompt with dynamic variable substitution. */
function processSystemPrompt(
    systemPrompt: string,
    model: string,
): string {
    if (!systemPrompt) {
        return "";
    }

    const now = new Date();
    const variables: Record<string, string> = {
        date: now.toLocaleDateString(),
        time: now.toLocaleTimeString(),
        datetime: now.toLocaleString(),
        model: model,
        timestamp: now.toISOString(),
        year: now.getFullYear().toString(),
        month: (now.getMonth() + 1).toString().padStart(2, "0"),
        day: now.getDate().toString().padStart(2, "0"),
    };

    let processed = systemPrompt;
    for (const [key, value] of Object.entries(variables)) {
        const regex = new RegExp(`\\{${key}\\}`, "gi");
        processed = processed.replace(regex, value);
    }

    return processed;
}
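
// Example (illustrative only; date formatting depends on the runtime locale):
//   "Today is {date}, served by {model}." -> "Today is 6/1/2025, served by gpt-4o."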

/** Validate the system prompt for potential issues (warnings only; never marks it invalid). */
function validateSystemPrompt(systemPrompt: string): { valid: boolean; warnings: string[] } {
    const warnings: string[] = [];

    if (systemPrompt.length > 10000) {
        warnings.push("System prompt is very long (>10000 chars). This may impact performance.");
    }

    // Check for unresolved variables
    const unresolvedVars = systemPrompt.match(/\{[^}]+\}/g);
    if (unresolvedVars) {
        const knownVars = ["date", "time", "datetime", "model", "timestamp", "year", "month", "day"];
        const unknown = unresolvedVars
            .map(v => v.slice(1, -1))
            .filter(v => !knownVars.includes(v.toLowerCase()));
        if (unknown.length > 0) {
            warnings.push(`Unknown variables detected: ${unknown.join(", ")}`);
        }
    }

    return { valid: true, warnings };
}

/** Convert internal chat history to the format expected by OpenAI. */
function toOpenAIMessages(
    history: Chat,
    systemPrompt?: string,
): ChatCompletionMessageParam[] {
    const messages: ChatCompletionMessageParam[] = [];

    // Add custom system prompt if provided
    if (systemPrompt) {
        messages.push({ role: "system", content: systemPrompt });
    }

    for (const message of history) {
        switch (message.getRole()) {
            case "system":
                // Skip system messages if we already added a custom system prompt
                if (!systemPrompt) {
                    messages.push({ role: "system", content: message.getText() });
                }
                break;

            case "user":
                messages.push({ role: "user", content: message.getText() });
                break;

            case "assistant": {
                const toolCalls: ChatCompletionMessageToolCall[] = message
                    .getToolCallRequests()
                    .map(toolCall => ({
                        id: toolCall.id ?? "",
                        type: "function",
                        function: {
                            name: toolCall.name,
                            arguments: JSON.stringify(toolCall.arguments ?? {}),
                        },
                    }));

                messages.push({
                    role: "assistant",
                    content: message.getText(),
                    ...(toolCalls.length ? { tool_calls: toolCalls } : {}),
                });
                break;
            }

            case "tool": {
                message.getToolCallResults().forEach(toolCallResult => {
                    messages.push({
                        role: "tool",
                        tool_call_id: toolCallResult.toolCallId ?? "",
                        content: toolCallResult.content,
                    } as ChatCompletionToolMessageParam);
                });
                break;
            }
        }
    }

    return messages;
}

/** Convert LM Studio tool definitions to OpenAI function-tool descriptors. */
function toOpenAITools(ctl: GeneratorController): ChatCompletionTool[] | undefined {
    const tools = ctl.getToolDefinitions().map<ChatCompletionTool>(t => ({
        type: "function",
        function: {
            name: t.function.name,
            description: t.function.description,
            parameters: t.function.parameters ?? {},
        },
    }));
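    // Return undefined rather than [] so the `tools` field is omitted from the request body.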
    return tools.length ? tools : undefined;
}

/* -------------------------------------------------------------------------- */
/*                            Stream-handling utils                           */
/* -------------------------------------------------------------------------- */

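/** Forward a user-initiated abort from LM Studio to the underlying HTTP stream. */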
function wireAbort(ctl: GeneratorController, stream: { controller: AbortController }) {
    ctl.onAborted(() => {
        console.info("Generation aborted by user.");
        stream.controller.abort();
    });
}

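/** Drain the completion stream, relaying text fragments and tool-call events to LM Studio. */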
async function consumeStream(stream: AsyncIterable<any>, ctl: GeneratorController) {
    let current: ToolCallState | null = null;

    function maybeFlushCurrentToolCall() {
        if (current === null || current.name === null) {
            return;
        }
        ctl.toolCallGenerationEnded({
            type: "function",
            name: current.name,
            // Guard against empty fragments: JSON.parse("") would throw.
            arguments: current.arguments ? JSON.parse(current.arguments) : {},
            id: current.id,
        });
        current = null;
    }

    for await (const chunk of stream) {
        console.info("Received chunk:", JSON.stringify(chunk));
        const delta = chunk.choices?.[0]?.delta as
            | {
                content?: string;
                tool_calls?: Array<{
                    index: number;
                    id?: string;
                    function?: { name?: string; arguments?: string };
                }>;
            }
            | undefined;

        if (!delta) continue;

        /* Text streaming */
        if (delta.content) {
            ctl.fragmentGenerated(delta.content);
        }

        /* Tool-call streaming */
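        // OpenAI-style streams send tool calls incrementally: the first delta for a
        // call carries its `id`, the name usually arrives once, and the JSON
        // arguments arrive as string fragments that must be concatenated.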
        for (const toolCall of delta.tool_calls ?? []) {
            if (toolCall.id !== undefined) {
                maybeFlushCurrentToolCall();
                current = { id: toolCall.id, name: null, index: toolCall.index, arguments: "" };
                ctl.toolCallGenerationStarted();
            }

            if (toolCall.function?.name && current) {
                current.name = toolCall.function.name;
                ctl.toolCallGenerationNameReceived(toolCall.function.name);
            }

            if (toolCall.function?.arguments && current) {
                current.arguments += toolCall.function.arguments;
                ctl.toolCallGenerationArgumentFragmentGenerated(toolCall.function.arguments);
            }
        }

        /* Finalize tool call */
        if (chunk.choices?.[0]?.finish_reason === "tool_calls") {
            maybeFlushCurrentToolCall();
        }
    }

    // Some providers end the stream without a "tool_calls" finish reason, so
    // flush any tool call that is still open before reporting completion.
    maybeFlushCurrentToolCall();

    console.info("Generation completed.");
}

/* -------------------------------------------------------------------------- */
/*                                     API                                    */
/* -------------------------------------------------------------------------- */

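/** Entry point: read plugin config, build the request, and stream the completion back. */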
export async function generate(ctl: GeneratorController, history: Chat) {
    const config = ctl.getPluginConfig(configSchematics);
    const model = config.get("model");
    const temperature = config.get("temperature");
    const top_p = config.get("top_p");
    const frequency_penalty = config.get("frequency_penalty");
    const presence_penalty = config.get("presence_penalty");
    const max_tokens = config.get("max_tokens");
    const systemPromptTemplate = config.get("systemPrompt");

    // Get global config
    const globalConfig = ctl.getGlobalPluginConfig(globalConfigSchematics);

    /* 1. Process and validate system prompt */
    let processedSystemPrompt: string | undefined;
    if (systemPromptTemplate) {
        processedSystemPrompt = processSystemPrompt(systemPromptTemplate, model);
        const validation = validateSystemPrompt(processedSystemPrompt);

        if (validation.warnings.length > 0) {
            console.warn("System prompt validation warnings:", validation.warnings);
        }

        console.info("Using custom system prompt:", processedSystemPrompt);
    }

    /* 2. Setup client & payload */
    const openai = createOpenAI(globalConfig, model);
    const messages = toOpenAIMessages(history, processedSystemPrompt);
    const tools = toOpenAITools(ctl);

    /* 3. Kick off streaming completion */
    const stream = await openai.chat.completions.create({
        model: model,
        messages,
        tools,
        stream: true,
        temperature,
        top_p,
        frequency_penalty,
        presence_penalty,
        ...(max_tokens > 0 ? { max_tokens } : {}),
    });

    /* 4. Abort wiring & stream processing */
    wireAbort(ctl, stream as any);
    await consumeStream(stream as any, ctl);
}