Project Files
src/fastvlm_server/server.py
"""FastAPI server for FastVLM image analysis."""
from __future__ import annotations
import base64
import io
import logging
import os
import signal
import time
from contextlib import asynccontextmanager
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel as PydanticModel, Field

from .config import get_config
from .model import BaseModel as VLMBaseModel, get_model
logger = logging.getLogger(__name__)
# Global model instance
_model: VLMBaseModel | None = None
def get_loaded_model() -> VLMBaseModel:
global _model
if _model is None:
raise RuntimeError("Model not loaded")
return _model
@asynccontextmanager
async def lifespan(app: FastAPI):
global _model
config = get_config()
if config.lazy:
logger.info("Lazy mode: models will load on first request")
else:
if config.model_path:
logger.info(f"Loading model from {config.model_path} with backend={config.resolved_backend}...")
start = time.time()
_model = get_model(config)
elapsed = time.time() - start
logger.info(f"Model loaded in {elapsed:.2f}s (backend: {config.resolved_backend})")
else:
logger.info("No model path configured — running in detection-only mode")
if config.florence2_model_path:
from .florence2_model import load as f2load
logger.info(f"Loading Florence-2 from {config.florence2_model_path}...")
f2start = time.time()
f2load(config.florence2_model_path)
f2elapsed = time.time() - f2start
logger.info(f"Florence-2 loaded in {f2elapsed:.2f}s")
yield
logger.info("Server shutting down")
_model = None
app = FastAPI(
title="FastVLM Server",
description="HTTP API for FastVLM image analysis",
version="3.0.0",
lifespan=lifespan
)
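# The app is typically served with uvicorn; a sketch (module path and
# host/port are illustrative and depend on your deployment):
#
#   uvicorn fastvlm_server.server:app --host 127.0.0.1 --port 8000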
# === Request / Response Models ===
class ImageInput(PydanticModel):
"""Single image with client-provided identifier."""
id: str = Field(..., description="Client-provided identifier (e.g. 'a1', 'v3')")
data: str = Field(..., description="Base64-encoded image data (JPEG/PNG/WebP)")
class ImageResult(PydanticModel):
"""Result for a single image analysis."""
id: str # Echoed from input
text: str
inference_time_ms: float
class AnalyzeRequest(PydanticModel):
"""Request body - supports single image OR array of images with IDs."""
image: Optional[str] = Field(default=None, description="Single base64-encoded image (deprecated, use images)")
images: Optional[List[ImageInput]] = Field(default=None, description="Array of images with IDs")
prompt: str = Field(default="Describe this image.", description="Question or instruction (applied to all images)")
max_tokens: Optional[int] = Field(default=None, ge=1, le=4096, description="Max tokens (default: server config)")
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature (default: server config)")
class AnalyzeResponse(PydanticModel):
"""Response - always returns array of results, order matches input."""
results: List[ImageResult]
total_inference_time_ms: float
backend: str
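# Example /analyze request body (illustrative; the base64 payloads are
# placeholders, not real image data):
#
#   {
#     "images": [
#       {"id": "a1", "data": "<base64 JPEG>"},
#       {"id": "a2", "data": "<base64 PNG>"}
#     ],
#     "prompt": "Describe this image.",
#     "max_tokens": 256
#   }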
class HealthResponse(PydanticModel):
status: str
model_loaded: bool
    florence2_loaded: Optional[bool] = None
    backend: Optional[str] = None
class StatusResponse(PydanticModel):
    model_path: Optional[str] = None  # None in detection-only mode
backend: str
host: str
port: int
pid: int
detect_backend: str
# === Endpoints ===
@app.get("/health", response_model=HealthResponse)
async def health():
config = get_config()
f2_loaded = None
if config.florence2_model_path:
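        # Peek at the module's private handle so we can report load state
        # without triggering a load.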
from .florence2_model import _model as _f2model
f2_loaded = _f2model is not None
return HealthResponse(
status="ok",
model_loaded=_model is not None,
florence2_loaded=f2_loaded,
backend=config.resolved_backend if _model else None
)
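# Quick liveness check (host/port and field values are illustrative):
#
#   curl http://127.0.0.1:8000/health
#   # -> {"status": "ok", "model_loaded": true, "florence2_loaded": null, "backend": "mlx"}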
@app.get("/status", response_model=StatusResponse)
async def status():
config = get_config()
return StatusResponse(
model_path=config.model_path,
backend=config.resolved_backend,
host=config.host,
port=config.port,
pid=os.getpid(),
detect_backend=config.detect_backend,
)
@app.post("/analyze", response_model=AnalyzeResponse)
async def analyze(request: AnalyzeRequest):
global _model
# Normalize input: accept 'images' (preferred) or deprecated 'image' (single)
if request.images:
image_list = request.images
elif request.image:
# Backward compat: wrap single image with auto-generated ID
image_list = [ImageInput(id="img_0", data=request.image)]
else:
raise HTTPException(status_code=400, detail="Either 'image' or 'images' must be provided")
if len(image_list) == 0:
raise HTTPException(status_code=400, detail="At least one image is required")
if len(image_list) > 16:
raise HTTPException(status_code=400, detail="Maximum 16 images per request")
    config = get_config()
    if _model is None:
        # Lazy-load on first request. Note: this load is not lock-guarded,
        # so concurrent first requests may each trigger a load.
        if not config.model_path:
            raise HTTPException(status_code=503, detail="FastVLM model not loaded. Configure mlxVisionModelPath and reload.")
        logger.info(f"Lazy-loading FastVLM from {config.model_path}...")
        start = time.time()
        _model = get_model(config)
        logger.info(f"FastVLM loaded in {time.time() - start:.2f}s (backend: {config.resolved_backend})")
    model = get_loaded_model()
# Fallback to server config if not provided in request
effective_max_tokens = request.max_tokens if request.max_tokens is not None else config.max_tokens
effective_temperature = request.temperature if request.temperature is not None else config.temperature
logger.info(
f"/analyze: {len(image_list)} image(s), "
f"prompt={request.prompt!r}, "
f"max_tokens={effective_max_tokens}, temperature={effective_temperature}"
)
results: list[ImageResult] = []
total_start = time.time()
for img in image_list:
try:
image_data = base64.b64decode(img.data)
image = Image.open(io.BytesIO(image_data)).convert("RGB")
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid image '{img.id}': {e}")
start = time.time()
try:
text = model.generate(
image=image,
prompt=request.prompt,
max_tokens=effective_max_tokens,
temperature=effective_temperature
)
except Exception as e:
logger.error(f"Inference error for '{img.id}': {e}")
raise HTTPException(status_code=500, detail=f"Inference failed for '{img.id}': {e}")
elapsed_ms = (time.time() - start) * 1000
results.append(ImageResult(id=img.id, text=text, inference_time_ms=round(elapsed_ms, 2)))
total_elapsed_ms = (time.time() - total_start) * 1000
return AnalyzeResponse(
results=results,
total_inference_time_ms=round(total_elapsed_ms, 2),
backend=config.resolved_backend
)
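# Example invocation (sketch; host/port are illustrative, and `base64 -i`
# is the macOS flag, so adjust for your platform; GNU base64 line-wraps by
# default, which would break the JSON):
#
#   curl -s http://127.0.0.1:8000/analyze \
#     -H 'Content-Type: application/json' \
#     -d "{\"images\": [{\"id\": \"a1\", \"data\": \"$(base64 -i photo.jpg)\"}], \"prompt\": \"What is shown?\"}"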
# === /detect — Florence-2 Object Detection ===
class DetectImageInput(PydanticModel):
id: str = Field(..., description="Client-provided identifier")
data: str = Field(..., description="Base64-encoded image data (JPEG/PNG/WebP)")
class DetectResult(PydanticModel):
id: str
bboxes: List[List[float]]
labels: List[str]
width: int
height: int
inference_time_ms: float
class DetectRequest(PydanticModel):
images: List[DetectImageInput]
task: str = Field(default="<OD>", description="Florence-2 task token, e.g. '<OD>' or '<OPEN_VOCABULARY_DETECTION>a person'")
class DetectResponse(PydanticModel):
results: List[DetectResult]
total_inference_time_ms: float
backend: str = "florence2-mlx"
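# Example /detect exchange (illustrative values; boxes are assumed to be
# pixel-space [x1, y1, x2, y2], matching typical Florence-2 <OD> output):
#
#   request:  {"images": [{"id": "v1", "data": "<base64>"}], "task": "<OD>"}
#   response: {"results": [{"id": "v1",
#                           "bboxes": [[12.0, 34.0, 560.0, 420.0]],
#                           "labels": ["person"],
#                           "width": 1920, "height": 1080,
#                           "inference_time_ms": 42.0}],
#              "total_inference_time_ms": 42.0,
#              "backend": "florence2-mlx"}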
@app.post("/detect", response_model=DetectResponse)
async def detect(request: DetectRequest):
if len(request.images) == 0:
raise HTTPException(status_code=400, detail="At least one image is required")
if len(request.images) > 16:
raise HTTPException(status_code=400, detail="Maximum 16 images per request")
config = get_config()
if config.detect_backend == "qwen3-vl":
from . import qwen3_vl_model
q3path = config.qwen3_vl_model_path
if not q3path:
raise HTTPException(status_code=503, detail="Qwen3-VL not configured (server started without --qwen3-vl-model-path)")
try:
qwen3_vl_model.load(q3path)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Qwen3-VL model load failed: {e}")
results: list[DetectResult] = []
total_start = time.time()
for img in request.images:
t0 = time.time()
try:
r = qwen3_vl_model.detect(img.data, task=request.task)
except Exception as e:
logger.error(f"Qwen3-VL inference error for '{img.id}': {e}")
raise HTTPException(status_code=500, detail=f"Inference failed for '{img.id}': {e}")
results.append(DetectResult(
id=img.id,
bboxes=r["bboxes"],
labels=r["labels"],
width=r["width"],
height=r["height"],
inference_time_ms=round((time.time() - t0) * 1000, 2),
))
return DetectResponse(
results=results,
total_inference_time_ms=round((time.time() - total_start) * 1000, 2),
backend="qwen3-vl",
)
# Florence-2 path (default)
from .florence2_model import load, detect as f2detect
f2path = config.florence2_model_path
if not f2path:
raise HTTPException(status_code=503, detail="Florence-2 not configured (server started without --florence2-model-path)")
try:
load(f2path)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Florence-2 model load failed: {e}")
results: list[DetectResult] = []
total_start = time.time()
for img in request.images:
t0 = time.time()
try:
r = f2detect(img.data, task=request.task)
except Exception as e:
logger.error(f"Florence-2 inference error for '{img.id}': {e}")
raise HTTPException(status_code=500, detail=f"Inference failed for '{img.id}': {e}")
results.append(DetectResult(
id=img.id,
bboxes=r["bboxes"],
labels=r["labels"],
width=r["width"],
height=r["height"],
inference_time_ms=round((time.time() - t0) * 1000, 2),
))
return DetectResponse(
results=results,
total_inference_time_ms=round((time.time() - total_start) * 1000, 2),
)
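# Open-vocabulary variant, per the task token convention documented on
# DetectRequest (sketch):
#
#   {"images": [{"id": "v1", "data": "<base64>"}], "task": "<OPEN_VOCABULARY_DETECTION>a person"}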
@app.post("/shutdown")
async def shutdown():
logger.info("Shutdown requested via API")
from . import qwen3_vl_model
qwen3_vl_model.shutdown()
os.kill(os.getpid(), signal.SIGTERM)
return {"status": "shutting_down"}