Forked from lmstudio/openai-compat-endpoint
src/generator.ts
import { type Chat, type GeneratorController, type InferParsedConfig } from "@lmstudio/sdk";
import { CONFIG_DEFAULTS, configSchematics, globalConfigSchematics } from "./config";
/* -------------------------------------------------------------------------- */
/*                                    Types                                   */
/* -------------------------------------------------------------------------- */
type ChatCompletionMessageParam =
  | { role: "system"; content: string }
  | { role: "user"; content: string }
  | {
      role: "assistant";
      content: string;
      tool_calls?: ChatCompletionMessageToolCall[];
    }
  | { role: "tool"; content: string; tool_call_id: string };
type ChatCompletionToolMessageParam = Extract<ChatCompletionMessageParam, { role: "tool" }>;
type ChatCompletionMessageToolCall = {
  id: string;
  type: "function";
  function: { name: string; arguments: string };
};
type ChatCompletionTool = {
  type: "function";
  function: {
    name: string;
    description?: string;
    parameters?: Record<string, unknown>;
  };
};
type ChatCompletionChunk = {
  choices?: Array<{
    delta?: {
      content?: string;
      tool_calls?: Array<{
        index: number;
        id?: string;
        function?: { name?: string; arguments?: string };
      }>;
    };
    finish_reason?: string | null;
  }>;
};
type ChatCompletionRequestBody = {
  model: string;
  messages: ChatCompletionMessageParam[];
  tools?: ChatCompletionTool[];
  stream: true;
};
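/**
 * Accumulator for a tool call that arrives in pieces. OpenAI-compatible
 * streams split a single call across many chunks: the first delta carries
 * the `id` (and usually the name), later deltas carry `arguments` as raw
 * JSON string fragments that must be concatenated before parsing. A
 * hypothetical call might stream as:
 *
 *   { index: 0, id: "call_abc", function: { name: "get_weather" } }
 *   { index: 0, function: { arguments: "{\"city\":" } }
 *   { index: 0, function: { arguments: "\"Berlin\"}" } }
 */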
type ToolCallState = {
  id: string;
  name: string | null;
  index: number;
  arguments: string;
};
type StreamHandle = {
  controller: AbortController;
  stream: AsyncIterable<ChatCompletionChunk>;
};
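/*
 * Retry/timeout policy: each HTTP attempt gets REQUEST_TIMEOUT_MS to settle
 * (the timer is cleared when the attempt resolves, so it does not bound the
 * streamed body), and failures are retried up to MAX_RETRIES attempts with
 * exponential backoff in fetchWithRetry below.
 */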
const REQUEST_TIMEOUT_MS = 15_000;
const MAX_RETRIES = 3;
/* -------------------------------------------------------------------------- */
/*                                Build helpers                               */
/* -------------------------------------------------------------------------- */
/** Convert internal chat history to the format expected by OpenAI. */
function toOpenAIMessages(history: Chat): ChatCompletionMessageParam[] {
  const messages: ChatCompletionMessageParam[] = [];
  for (const message of history) {
    switch (message.getRole()) {
      case "system":
        messages.push({ role: "system", content: message.getText() });
        break;
      case "user":
        messages.push({ role: "user", content: message.getText() });
        break;
      case "assistant": {
        const toolCalls: ChatCompletionMessageToolCall[] = message
          .getToolCallRequests()
          .map(toolCall => ({
            id: toolCall.id ?? "",
            type: "function",
            function: {
              name: toolCall.name,
              arguments: JSON.stringify(toolCall.arguments ?? {}),
            },
          }));
        messages.push({
          role: "assistant",
          content: message.getText(),
          ...(toolCalls.length ? { tool_calls: toolCalls } : {}),
        });
        break;
      }
      case "tool": {
        message.getToolCallResults().forEach(toolCallResult => {
          messages.push({
            role: "tool",
            tool_call_id: toolCallResult.toolCallId ?? "",
            content: toolCallResult.content,
          } as ChatCompletionToolMessageParam);
        });
        break;
      }
    }
  }
  return messages;
}
/** Convert LM Studio tool definitions to OpenAI function-tool descriptors. */
function toOpenAITools(ctl: GeneratorController): ChatCompletionTool[] | undefined {
  const tools = ctl.getToolDefinitions().map<ChatCompletionTool>(t => ({
    type: "function",
    function: {
      name: t.function.name,
      description: t.function.description,
      parameters: t.function.parameters ?? {},
    },
  }));
  return tools.length ? tools : undefined;
}
/* -------------------------------------------------------------------------- */
/*                            Stream-handling utils                           */
/* -------------------------------------------------------------------------- */
function wireAbort(ctl: GeneratorController, stream: { controller: AbortController }) {
  ctl.onAborted(() => {
    console.info("Generation aborted by user.");
    stream.controller.abort();
  });
}
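/**
 * Drain the chunk stream and translate it into GeneratorController events:
 * `content` deltas are forwarded as text fragments, and tool-call deltas are
 * accumulated in a ToolCallState that is flushed as one complete call when a
 * new call id appears, when finish_reason is "tool_calls", or when the
 * stream ends.
 */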
async function consumeStream(
  stream: AsyncIterable<ChatCompletionChunk>,
  ctl: GeneratorController,
) {
  let current: ToolCallState | null = null;
  function maybeFlushCurrentToolCall() {
    if (current === null || current.name === null) {
      return;
    }
    ctl.toolCallGenerationEnded({
      type: "function",
      name: current.name,
      // Guard against providers that send no argument fragments at all.
      arguments: current.arguments ? JSON.parse(current.arguments) : {},
      id: current.id,
    });
    current = null;
  }
  for await (const chunk of stream) {
    console.info("Received chunk:", JSON.stringify(chunk));
    const delta = chunk.choices?.[0]?.delta;
    if (!delta) continue;
    /* Text streaming */
    if (delta.content) {
      ctl.fragmentGenerated(delta.content);
    }
    /* Tool-call streaming */
    for (const toolCall of delta.tool_calls ?? []) {
      if (toolCall.id !== undefined) {
        // A fresh id marks a new tool call; flush the previous one first.
        maybeFlushCurrentToolCall();
        current = { id: toolCall.id, name: null, index: toolCall.index, arguments: "" };
        ctl.toolCallGenerationStarted();
      }
      if (toolCall.function?.name && current) {
        current.name = toolCall.function.name;
        ctl.toolCallGenerationNameReceived(toolCall.function.name);
      }
      if (toolCall.function?.arguments && current) {
        current.arguments += toolCall.function.arguments;
        ctl.toolCallGenerationArgumentFragmentGenerated(toolCall.function.arguments);
      }
    }
    /* Finalize tool call */
    if (chunk.choices?.[0]?.finish_reason === "tool_calls" && current?.name) {
      maybeFlushCurrentToolCall();
    }
  }
  // Some providers end the stream without a final finish_reason chunk;
  // flush any tool call that is still open so it is not silently dropped.
  maybeFlushCurrentToolCall();
  console.info("Generation completed.");
}
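/** Read at most the first 400 characters of an error body for logging; never throws. */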
async function readErrorSnippet(response: Response) {
  try {
    const text = await response.text();
    return text.slice(0, 400);
  } catch {
    return "<unable to read body>";
  }
}
async function wait(ms: number) {
  return new Promise(resolve => setTimeout(resolve, ms));
}
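/**
 * Merge an optional external abort signal with a per-attempt timeout into
 * one AbortSignal for fetch. This is done by hand for compatibility; on
 * runtimes that support it, AbortSignal.any([external, AbortSignal.timeout(ms)])
 * expresses the same idea.
 */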
function combineSignals(externalSignal: AbortSignal | undefined, timeoutMs: number) {
  const controller = new AbortController();
  const timeoutId = setTimeout(() => controller.abort(new Error("Request timed out")), timeoutMs);
  if (externalSignal) {
    if (externalSignal.aborted) {
      controller.abort(externalSignal.reason);
    } else {
      const onAbort = () => {
        controller.abort(externalSignal.reason);
      };
      externalSignal.addEventListener("abort", onAbort, { once: true });
    }
  }
  return {
    signal: controller.signal,
    dispose: () => clearTimeout(timeoutId),
  };
}
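/**
 * POST with retries: 5xx responses and transient network/timeout failures
 * are retried with exponential backoff (1 s, then 2 s, failing on the 3rd
 * attempt), while 4xx responses and user-initiated aborts fail immediately.
 */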
async function fetchWithRetry(
  url: string,
  init: Omit<RequestInit, "signal">,
  externalSignal: AbortSignal,
): Promise<Response> {
  for (let attempt = 1; attempt <= MAX_RETRIES; attempt++) {
    const { signal, dispose } = combineSignals(externalSignal, REQUEST_TIMEOUT_MS);
    try {
      const response = await fetch(url, { ...init, signal });
      if (response.status >= 500) {
        const snippet = await readErrorSnippet(response);
        console.error(`[IONOS] ${response.status} on attempt ${attempt}: ${snippet}`);
        if (attempt === MAX_RETRIES) {
          throw new Error(`IONOS API returned ${response.status}: ${snippet}`);
        }
        const backoffMs = 1000 * Math.pow(2, attempt - 1);
        await wait(backoffMs);
        continue;
      }
      if (!response.ok) {
        const snippet = await readErrorSnippet(response);
        console.error(`[IONOS] ${response.status}: ${snippet}`);
        // Client errors (4xx) are not transient; mark them so the catch
        // block below fails fast instead of retrying.
        const error = new Error(`Request failed with ${response.status}: ${snippet}`);
        (error as Error & { noRetry?: boolean }).noRetry = true;
        throw error;
      }
      return response;
    } catch (error) {
      if ((error as Error).name === "AbortError" && externalSignal.aborted) {
        throw error;
      }
      if ((error as { noRetry?: boolean }).noRetry === true || attempt === MAX_RETRIES) {
        throw error;
      }
      const backoffMs = 1000 * Math.pow(2, attempt - 1);
      await wait(backoffMs);
    } finally {
      dispose();
    }
  }
  throw new Error("Failed to call IONOS API after retries.");
}
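/**
 * Minimal Server-Sent-Events parser for the OpenAI-style streaming wire
 * format: events are separated by blank lines, each carries one or more
 * `data:` lines, and the stream ends with a literal `[DONE]` sentinel:
 *
 *   data: {"choices":[{"delta":{"content":"Hel"}}]}
 *
 *   data: {"choices":[{"delta":{"content":"lo"}}]}
 *
 *   data: [DONE]
 */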
async function* parseSseStream(
  body: ReadableStream<Uint8Array>,
): AsyncGenerator<ChatCompletionChunk> {
  const reader = body.getReader();
  const decoder = new TextDecoder("utf-8");
  let buffer = "";
  try {
    while (true) {
      const { value, done } = await reader.read();
      if (done) {
        break;
      }
      buffer += decoder.decode(value, { stream: true });
      let delimiterIndex: number;
      while ((delimiterIndex = buffer.indexOf("\n\n")) !== -1) {
        const rawEvent = buffer.slice(0, delimiterIndex).trim();
        buffer = buffer.slice(delimiterIndex + 2);
        if (!rawEvent.startsWith("data:")) {
          continue;
        }
        const dataLines = rawEvent
          .split("\n")
          .filter(line => line.startsWith("data:"))
          .map(line => line.slice(5).trimStart());
        const dataPayload = dataLines.join("\n");
        if (!dataPayload) continue;
        if (dataPayload === "[DONE]") {
          return;
        }
        try {
          const parsed = JSON.parse(dataPayload) as ChatCompletionChunk;
          yield parsed;
        } catch (error) {
          console.error("Failed to parse SSE chunk:", error);
        }
      }
    }
  } finally {
    reader.releaseLock();
  }
}
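/**
 * Open a streaming chat completion against an OpenAI-compatible endpoint.
 * The returned AbortController cancels both the HTTP request and the SSE
 * body; the trailing-slash strip keeps apiBase values like
 * "https://host/v1/" and "https://host/v1" equivalent.
 */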
async function createChatCompletionStream(
  apiBase: string,
  apiKey: string,
  body: ChatCompletionRequestBody,
): Promise<StreamHandle> {
  const controller = new AbortController();
  const response = await fetchWithRetry(
    `${apiBase.replace(/\/$/, "")}/chat/completions`,
    {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${apiKey}`,
      },
      body: JSON.stringify(body),
    },
    controller.signal,
  );
  if (!response.body) {
    throw new Error("IONOS API response had no body to stream.");
  }
  return {
    controller,
    stream: parseSseStream(response.body),
  };
}
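/** UI-configured key takes precedence over the IONOS_API_KEY environment variable. */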
function resolveApiKey(globalConfig: InferParsedConfig<typeof globalConfigSchematics>) {
  const keyFromUi = globalConfig.get("openaiApiKey")?.trim();
  const keyFromEnv = process.env.IONOS_API_KEY?.trim();
  return keyFromUi || keyFromEnv || "";
}
/* -------------------------------------------------------------------------- */
/*                                     API                                    */
/* -------------------------------------------------------------------------- */
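/**
 * Plugin entry point: builds an OpenAI-style request body from the chat
 * history and plugin config, streams the completion from the configured
 * IONOS endpoint, and forwards text fragments and tool calls to the
 * controller.
 */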
export async function generate(ctl: GeneratorController, history: Chat) {
  const config = ctl.getPluginConfig(configSchematics);
  const model = config.get("model") || CONFIG_DEFAULTS.DEFAULT_MODEL;
  const apiBase = config.get("apiBaseUrl")?.trim() || CONFIG_DEFAULTS.DEFAULT_API_BASE;
  const globalConfig = ctl.getGlobalPluginConfig(globalConfigSchematics);
  const apiKey = resolveApiKey(globalConfig);
  if (!apiKey) {
    throw new Error(
      "No IONOS API key found. Set 'IONOS API Key' in the plugin UI or define IONOS_API_KEY in the environment.",
    );
  }
  /* 1. Setup client & payload */
  const messages = toOpenAIMessages(history);
  const tools = toOpenAITools(ctl);
  const body: ChatCompletionRequestBody = {
    model,
    messages,
    stream: true,
    ...(tools ? { tools } : {}),
  };
  /* 2. Kick off streaming completion */
  const { controller, stream } = await createChatCompletionStream(apiBase, apiKey, body);
  /* 3. Abort wiring & stream processing */
  wireAbort(ctl, { controller });
  await consumeStream(stream, ctl);
}