// src/toolsProvider.ts

import { tool, Tool, ToolsProviderController } from "@lmstudio/sdk";
import { writeFile } from "fs/promises";
import { join } from "path";
import { z } from "zod";
import { configSchematics } from "./config";
import { fetchHTML, getImageHeaders } from "./browser";
import { extractImages, extractLinks, extractTextContent } from "./extractor";
import { cleanTag, undefinedIfAuto } from "./utils";
import { configureUndiciDispatcher } from "./tls";

/** Captures a MIME subtype such as "png" from a Content-Type like "image/png". */
const CONTENT_TYPE_EXTENSION_PATTERN = /image\/([\w]+)/;
/** Captures a trailing file extension from a URL, allowing a query string after it. */
const URL_EXTENSION_PATTERN = /\.([\w]+)(?:\?.*)?$/;

/** True when `error` is the DOMException raised for an aborted request. */
function isAbortError(error: unknown): boolean {
	return error instanceof DOMException && error.name === "AbortError";
}

/** Best-effort human-readable message for a thrown value of unknown type. */
function describeError(error: unknown): string {
	return error instanceof Error ? error.message : String(error);
}

/**
 * Downloads images into `workingDirectory` in parallel and renders one
 * markdown string per input URL, preserving input order.
 *
 * @param urls Image URLs to download. A URL that already starts with
 *   `workingDirectory` is not fetched; it is embedded as-is.
 * @param workingDirectory Directory the downloaded files are written to.
 * @param referer Referring page passed to `getImageHeaders`; when empty or
 *   undefined, each image's own origin is used instead.
 * @param warn Receives a message for each image that fails (non-OK status,
 *   empty body, or a thrown error other than an abort).
 * @param signal Abort signal forwarded to every `fetch`; aborted images are
 *   rendered as failures without a warning.
 * @returns `markdowns`: `![Image i](path)` for each success, or an
 *   "Error fetching image from URL: ..." line for each failure.
 *   `downloaded`: how many entries succeeded.
 */
async function downloadImagesAsMarkdown(
	urls: readonly string[],
	workingDirectory: string,
	referer: string | undefined,
	warn: (message: string) => void,
	signal: AbortSignal,
): Promise<{ markdowns: string[]; downloaded: number }> {
	// One timestamp per batch so file names share a prefix and differ only by index.
	const timestamp = Date.now();

	const localPaths = await Promise.all(urls.map(async (url, i): Promise<string | null> => {
		if (url.startsWith(workingDirectory)) return url;

		const index = i + 1;
		try {
			const headers = getImageHeaders(url, referer || new URL(url).origin);
			const imageResponse = await fetch(url, { method: "GET", signal, headers });

			if (!imageResponse.ok) {
				warn(`Failed to fetch image ${index} (${imageResponse.status}): ${url}`);
				return null;
			}

			const bytes = Buffer.from(await imageResponse.arrayBuffer());
			if (bytes.length === 0) {
				warn(`Image ${index} is empty: ${url}`);
				return null;
			}

			// Prefer the server-declared MIME subtype, then the URL's own extension, then "jpg".
			const fileExtension =
				CONTENT_TYPE_EXTENSION_PATTERN.exec(imageResponse.headers.get('content-type') || '')?.[1]
				|| URL_EXTENSION_PATTERN.exec(url)?.[1]
				|| 'jpg';
			const fileName = `${timestamp}-${index}.${fileExtension}`;
			const filePath = join(workingDirectory, fileName);
			// Strip the working-directory prefix from filePath (the remainder may keep a
			// leading separator), falling back to the bare file name.
			const localPath = filePath.split(join(workingDirectory, ''))[1] || fileName;

			await writeFile(filePath, bytes);
			// Forward slashes so the path works inside markdown on Windows too.
			return localPath.replace(/\\/g, '/');
		} catch (error: unknown) {
			if (isAbortError(error)) return null;
			warn(`Error fetching image ${index}: ${describeError(error)}`);
			return null;
		}
	}));

	const markdowns = localPaths.map((localPath, i) =>
		localPath ? `![Image ${i + 1}](${localPath})` : 'Error fetching image from URL: ' + urls[i]);
	const downloaded = localPaths.filter(localPath => localPath !== null).length;
	return { markdowns, downloaded };
}

/**
 * Builds the tools this plugin exposes: "Visit Website" and "View Images".
 *
 * Calls `configureUndiciDispatcher()` once before creating the tools. Both
 * tools read their limits from the plugin config on every call.
 *
 * @param ctl Controller used for plugin config and the working directory
 *   that downloaded images are written to.
 * @returns The tools, Visit Website first.
 */
export async function toolsProvider(ctl: ToolsProviderController): Promise<Tool[]> {
	// Initialize the advanced TLS agent on startup
	await configureUndiciDispatcher();

	const viewImagesTool = tool({
		name: "View Images",
		description: "Download images from a website or a list of image URLs to make them viewable.",
		parameters: {
			imageURLs: z.array(z.string().url()).optional().describe("List of image URLs to view that were not obtained via the Visit Website tool."),
			websiteURL: z.string().url().optional().describe("The URL of the website, whose images to view."),
			maxImages: z.number().int().min(1).max(200).optional().describe("Maximum number of images to view when websiteURL is provided."),
		},
		implementation: async ({ imageURLs, websiteURL, maxImages }, { status, warn, signal }) => {
			try {
				// The plugin config value takes precedence. undefinedIfAuto presumably maps
				// the -1 sentinel to undefined, letting the argument or default apply.
				const imageLimit = undefinedIfAuto(ctl.getPluginConfig(configSchematics).get("maxImages"), -1) ?? maxImages ?? 10;

				// Copy so that appending website images never mutates the caller's array.
				const imageURLsToDownload = [...(imageURLs ?? [])];

				if (websiteURL) {
					status("Fetching image URLs from website...");
					const { body } = await fetchHTML(websiteURL, signal, warn);
					imageURLsToDownload.push(...extractImages(body, websiteURL, imageLimit).map(x => x[1]));
				}

				if (imageURLsToDownload.length === 0) {
					warn('Error fetching images');
					return imageURLsToDownload;
				}

				status("Downloading images...");
				const { markdowns, downloaded } = await downloadImagesAsMarkdown(
					imageURLsToDownload, ctl.getWorkingDirectory(), websiteURL, warn, signal);

				// Report only the images that actually succeeded.
				status(`Downloaded ${downloaded} images successfully.`);
				return markdowns;
			} catch (error: unknown) {
				if (isAbortError(error)) {
					return "Image download aborted by user.";
				}
				console.error(error);
				warn(`Error during image download: ${describeError(error)}`);
				return `Error: ${describeError(error)}`;
			}
		}
	});

	const visitWebsiteTool = tool({
		name: "Visit Website",
		description: "Visit a website and return its title, headings, links, images, and text content. Images are automatically downloaded and viewable.",
		parameters: {
			url: z.string().url().describe("The URL of the website to visit"),
			findInPage: z.array(z.string()).optional().describe("Highly recommended! Optional search terms to prioritize which links, images, and content to return."),
			maxLinks: z.number().int().min(0).max(200).optional().describe("Maximum number of links to extract from the page."),
			maxImages: z.number().int().min(0).max(200).optional().describe("Maximum number of images to extract from the page."),
			contentLimit: z.number().int().min(0).max(100_000).optional().describe("Maximum text content length to extract from the page."),
		},
		implementation: async ({ url, maxLinks, maxImages, contentLimit, findInPage: searchTerms }, { status, warn, signal }) => {
			status("Visiting website...");

			try {
				// Plugin config values take precedence over arguments (see undefinedIfAuto).
				const config = ctl.getPluginConfig(configSchematics);
				const linkLimit = undefinedIfAuto(config.get("maxLinks"), -1) ?? maxLinks ?? 40;
				const imageLimit = undefinedIfAuto(config.get("maxImages"), -1) ?? maxImages ?? 10;
				const textLimit = undefinedIfAuto(config.get("contentLimit"), -1) ?? contentLimit ?? 60000;

				const { head, body } = await fetchHTML(url, signal, warn);
				status("Website visited successfully.");

				// Only the first occurrence of each tag is reported.
				const title = cleanTag(head.match(/<title[^>]*>([\s\S]*?)<\/title>/i)?.[1] || "");
				const h1 = cleanTag(body.match(/<h1[^>]*>([\s\S]*?)<\/h1>/i)?.[1] || "");
				const h2 = cleanTag(body.match(/<h2[^>]*>([\s\S]*?)<\/h2>/i)?.[1] || "");
				const h3 = cleanTag(body.match(/<h3[^>]*>([\s\S]*?)<\/h3>/i)?.[1] || "");

				const links = linkLimit && extractLinks(body, url, linkLimit, searchTerms);
				const imagesToFetch = imageLimit ? extractImages(body, url, imageLimit, searchTerms) : [];

				let images: [string, string][] = [];
				if (imagesToFetch.length > 0) {
					status("Downloading images...");
					// Download the already-extracted images directly, with this page as referer.
					// Going through View Images with websiteURL set would re-fetch the page and
					// append its images a second time, misaligning results with imagesToFetch.
					const { markdowns, downloaded } = await downloadImagesAsMarkdown(
						imagesToFetch.map(x => x[1]), ctl.getWorkingDirectory(), url, warn, signal);
					status(`Downloaded ${downloaded} images successfully.`);
					images = imagesToFetch.map((image, i) => [image[0], markdowns[i]] as [string, string]);
				}

				const content = textLimit ? extractTextContent(body, textLimit, searchTerms) : "";

				return {
					url, title, h1, h2, h3,
					...(links ? { links } : {}),
					...(images.length > 0 ? { images } : {}),
					...(content ? { content } : {}),
				};
			} catch (error: unknown) {
				if (isAbortError(error)) {
					return "Website visit aborted by user.";
				}
				console.error(error);
				warn(`Error during website visit: ${describeError(error)}`);
				return `Error: ${describeError(error)}`;
			}
		},
	});

	return [visitWebsiteTool, viewImagesTool];
}