// Forked from danielsig/visit-website
// src/toolsProvider.ts
import { tool, Tool, ToolsProviderController } from "@lmstudio/sdk";
import { writeFile } from "fs/promises";
import { join } from "path";
import { z } from "zod";
import { configSchematics } from "./config";
import { fetchHTML, getImageHeaders } from "./browser";
import { extractImages, extractLinks, extractTextContent } from "./extractor";
import { cleanTag, undefinedIfAuto } from "./utils";
import { configureUndiciDispatcher } from "./tls";
/**
 * Builds the tools exposed by this plugin: "Visit Website" and "View Images".
 *
 * `configureUndiciDispatcher()` is awaited once before any tool is created, so
 * every later request made by the tools goes through that dispatcher.
 *
 * For each numeric limit (`maxLinks`, `maxImages`, `contentLimit`) the plugin
 * configuration wins over the tool-call argument, which wins over the built-in
 * default. `undefinedIfAuto(value, -1)` presumably yields `undefined` when the
 * configured value is the `-1` "auto" sentinel — confirm against `./utils`.
 *
 * @param ctl - Controller used to read the plugin configuration and the
 *   working directory that downloaded images are written to.
 * @returns The tools, "Visit Website" first and "View Images" second.
 */
export async function toolsProvider(ctl: ToolsProviderController): Promise<Tool[]> {
  /** The part of a tool-call context that the image downloader needs. */
  type ToolFeedback = {
    status: (text: string) => void;
    warn: (text: string) => void;
    signal: AbortSignal;
  };

  /** True when `error` is the rejection produced by an aborted request. */
  const isAbortError = (error: unknown): boolean =>
    error instanceof DOMException && error.name === "AbortError";

  /** Readable message for any thrown value, not only `Error` instances. */
  const errorMessage = (error: unknown): string =>
    error instanceof Error ? error.message : String(error);

  // Initialize the advanced TLS agent on startup
  await configureUndiciDispatcher();

  /**
   * Downloads every URL into the working directory and returns one entry per
   * input URL, in the same order: a Markdown image reference on success, or an
   * "Error fetching image from URL: ..." string on failure. Per-image failures
   * are reported through `warn` and never reject the returned promise.
   *
   * @param urls - Image URLs; entries already inside the working directory are
   *   passed through without being fetched.
   * @param referer - Page the images belong to; when missing, each image's own
   *   origin is used instead.
   */
  const downloadImages = async (
    urls: readonly string[],
    referer: string | undefined,
    { status, warn, signal }: ToolFeedback,
  ): Promise<string[]> => {
    status("Downloading images...");
    if (urls.length === 0) {
      warn('Error fetching images');
      return [];
    }

    const workingDirectory = ctl.getWorkingDirectory();
    // One timestamp per batch keeps the file names of a batch grouped together.
    const timestamp = Date.now();

    const localPaths = await Promise.all(urls.map(async (url, i) => {
      if (url.startsWith(workingDirectory)) return url;
      const index = i + 1;
      try {
        // `||` on purpose: an empty referer must fall back to the origin too.
        const headers = getImageHeaders(url, referer || new URL(url).origin);
        const imageResponse = await fetch(url, {
          method: "GET",
          signal,
          headers,
        });
        if (!imageResponse.ok) {
          warn(`Failed to fetch image ${index} (${imageResponse.status}): ${url}`);
          return null;
        }
        const bytes = Buffer.from(await imageResponse.arrayBuffer());
        if (bytes.length === 0) {
          warn(`Image ${index} is empty: ${url}`);
          return null;
        }
        // Extension: content-type subtype first, then the URL's suffix, then "jpg".
        const fileExtension = /image\/([\w]+)/.exec(imageResponse.headers.get('content-type') || '')?.[1]
          || /\.([\w]+)(?:\?.*)?$/.exec(url)?.[1]
          || 'jpg';
        const fileName = `${timestamp}-${index}.${fileExtension}`;
        const filePath = join(workingDirectory, fileName);
        // Whatever follows the working-directory prefix, or the bare file name.
        const localPath = filePath.split(join(workingDirectory, ''))[1] || fileName;
        await writeFile(filePath, bytes, 'binary');
        return localPath.replace(/\\/g, '/');
      } catch (error: unknown) {
        // An abort is the user's choice, so it is not worth a warning.
        if (!isAbortError(error)) warn(`Error fetching image ${index}: ${errorMessage(error)}`);
        return null;
      }
    }));

    // Emit a Markdown image reference to the saved file; previously the success
    // branch produced an empty string and the downloaded path was lost.
    const markdowns = localPaths.map((path, i) =>
      path ? `![Image ${i + 1}](${path})` : 'Error fetching image from URL: ' + urls[i]);
    const successCount = localPaths.filter(path => Boolean(path)).length;
    status(`Downloaded ${successCount} images successfully.`);
    return markdowns;
  };

  const viewImagesTool = tool({
    name: "View Images",
    description: "Download images from a website or a list of image URLs to make them viewable.",
    parameters: {
      imageURLs: z.array(z.string().url()).optional().describe("List of image URLs to view that were not obtained via the Visit Website tool."),
      websiteURL: z.string().url().optional().describe("The URL of the website, whose images to view."),
      maxImages: z.number().int().min(1).max(200).optional().describe("Maximum number of images to view when websiteURL is provided."),
    },
    implementation: async ({ imageURLs, websiteURL, maxImages }, context) => {
      const { status, warn, signal } = context;
      try {
        const imageLimit = undefinedIfAuto(ctl.getPluginConfig(configSchematics).get("maxImages"), -1) ?? maxImages ?? 10;
        // Copy so the caller's array is never mutated by the push below.
        const imageURLsToDownload = [...(imageURLs ?? [])];
        if (websiteURL) {
          status("Fetching image URLs from website...");
          const { body } = await fetchHTML(websiteURL, signal, warn);
          imageURLsToDownload.push(...extractImages(body, websiteURL, imageLimit).map(x => x[1]));
        }
        return await downloadImages(imageURLsToDownload, websiteURL, context);
      } catch (error: unknown) {
        if (isAbortError(error)) {
          return "Image download aborted by user.";
        }
        console.error(error);
        warn(`Error during image download: ${errorMessage(error)}`);
        return `Error: ${errorMessage(error)}`;
      }
    },
  });

  const visitWebsiteTool = tool({
    name: "Visit Website",
    description: "Visit a website and return its title, headings, links, images, and text content. Images are automatically downloaded and viewable.",
    parameters: {
      url: z.string().url().describe("The URL of the website to visit"),
      findInPage: z.array(z.string()).optional().describe("Highly recommended! Optional search terms to prioritize which links, images, and content to return."),
      maxLinks: z.number().int().min(0).max(200).optional().describe("Maximum number of links to extract from the page."),
      maxImages: z.number().int().min(0).max(200).optional().describe("Maximum number of images to extract from the page."),
      contentLimit: z.number().int().min(0).max(100_000).optional().describe("Maximum text content length to extract from the page."),
    },
    implementation: async ({ url, maxLinks, maxImages, contentLimit, findInPage: searchTerms }, context) => {
      const { status, warn, signal } = context;
      status("Visiting website...");
      try {
        const linkLimit = undefinedIfAuto(ctl.getPluginConfig(configSchematics).get("maxLinks"), -1) ?? maxLinks ?? 40;
        const imageLimit = undefinedIfAuto(ctl.getPluginConfig(configSchematics).get("maxImages"), -1) ?? maxImages ?? 10;
        const textLimit = undefinedIfAuto(ctl.getPluginConfig(configSchematics).get("contentLimit"), -1) ?? contentLimit ?? 60000;

        const { head, body } = await fetchHTML(url, signal, warn);
        status("Website visited successfully.");

        const title = cleanTag(head.match(/<title[^>]*>([\s\S]*?)<\/title>/i)?.[1] || "");
        const h1 = cleanTag(body.match(/<h1[^>]*>([\s\S]*?)<\/h1>/i)?.[1] || "");
        const h2 = cleanTag(body.match(/<h2[^>]*>([\s\S]*?)<\/h2>/i)?.[1] || "");
        const h3 = cleanTag(body.match(/<h3[^>]*>([\s\S]*?)<\/h3>/i)?.[1] || "");

        // A limit of 0 disables the corresponding section entirely.
        const links = linkLimit && extractLinks(body, url, linkLimit, searchTerms);
        const imagesToFetch = imageLimit ? extractImages(body, url, imageLimit, searchTerms) : [];

        let images: [string, string][] = [];
        if (imagesToFetch.length > 0) {
          try {
            // Download exactly the images selected above. Going through the
            // "View Images" tool with `websiteURL` set would fetch the page a
            // second time and append extra images, so the results would no
            // longer line up index-for-index with `imagesToFetch`.
            const markdowns = await downloadImages(imagesToFetch.map(x => x[1]), url, context);
            images = imagesToFetch.map((entry, index) => [entry[0], markdowns[index]] as [string, string]);
          } catch (error: unknown) {
            // Image trouble must not cost the caller the page's text content.
            if (!isAbortError(error)) {
              console.error(error);
              warn(`Error during image download: ${errorMessage(error)}`);
            }
          }
        }

        const content = textLimit ? extractTextContent(body, textLimit, searchTerms) : "";

        return {
          url, title, h1, h2, h3,
          ...(links ? { links } : {}),
          ...(images.length > 0 ? { images } : {}),
          ...(content ? { content } : {}),
        };
      } catch (error: unknown) {
        if (isAbortError(error)) {
          return "Website visit aborted by user.";
        }
        console.error(error);
        warn(`Error during website visit: ${errorMessage(error)}`);
        return `Error: ${errorMessage(error)}`;
      }
    },
  });

  return [visitWebsiteTool, viewImagesTool];
}