Project Files
src/fastvlm_server/model.py
"""Model backend for FastVLM with explicit backend selection."""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from PIL import Image
from .config import ServerConfig
logger = logging.getLogger(__name__)
class BaseModel(ABC):
"""Abstract base class for VLM backends."""
@abstractmethod
def generate(
self,
image: Image.Image,
prompt: str,
max_tokens: int = 256,
temperature: float = 0.7
) -> str:
"""Generate text from image + prompt."""
pass
class CoreMLVisionTower:
"""Wraps fastvithd.mlpackage for ANE-accelerated vision encoding."""
def __init__(self, mlpackage_path: str):
import coremltools as ct
# CPU_ONLY: avoids MLE5ExecutionStream (ANE) which releases Python objects
# on a background thread without holding the GIL → crash in Python 3.9.
# CPU keeps Metal GPU free for MLX language generation.
self._ct_model = ct.models.MLModel(
mlpackage_path,
compute_units=ct.ComputeUnit.CPU_ONLY,
)
logger.info(f"CoreML vision tower loaded (CPU_ONLY): {mlpackage_path}")
def predict(self, inputs_dict):
"""Called by patched fastvlm model: predict({"images": np_array}) → dict."""
return self._ct_model.predict(inputs_dict)
def __call__(self, x_nhwc):
import numpy as np
import mlx.core as mx
# x_nhwc: (B, H, W, C) bfloat16/float16 — CoreML wants NCHW float32
x_nchw = np.array(x_nhwc.transpose(0, 3, 1, 2).astype(mx.float32))
pred = self._ct_model.predict({"images": x_nchw})
        # image_features: (1, 256, 3072); reshape to (1, 16, 16, 3072) so that
        # fastvlm.py's `B, H, W, C = x.shape` followed by `reshape(B, H*W, C)`
        # still works unchanged downstream.
features = pred["image_features"]
B, N, C = features.shape
s = int(N ** 0.5)
return None, mx.array(features.reshape(B, s, s, C)), None
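# Illustrative call contract for CoreMLVisionTower (a sketch; shapes follow
# the comments in __call__, and the .mlpackage path is a placeholder):
#
#   tower = CoreMLVisionTower("<model_dir>/fastvithd.mlpackage")
#   _, feats, _ = tower(x_nhwc)   # x_nhwc: (B, H, W, C) mx.array
#   feats.shape                   # (1, 16, 16, 3072)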
class MLXModel(BaseModel):
"""FastVLM model wrapper with explicit backend mode (ane/mlx)."""
def __init__(self, model_path: str, backend: str):
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config
self._generate_fn = generate
self._apply_chat_template = apply_chat_template
logger.info(f"Loading MLX model from {model_path}")
self.model, self.processor = load(model_path)
self.config = load_config(model_path)
# No silent fallback: backend is explicit and enforced.
if backend == "ane":
mlpackage = Path(model_path) / "fastvithd.mlpackage"
if not mlpackage.exists():
raise RuntimeError(
"Backend 'ane' requested but fastvithd.mlpackage is missing from model directory"
)
self.model.vision_tower = CoreMLVisionTower(str(mlpackage))
logger.info("ANE vision tower active (CoreML)")
elif backend == "mlx":
logger.info("MLX vision tower active (Metal)")
else:
raise RuntimeError(f"Unsupported backend: {backend}")
logger.info("MLX model loaded")
def generate(
self,
image: Image.Image,
prompt: str,
max_tokens: int = 256,
temperature: float = 0.7
) -> str:
formatted_prompt = self._apply_chat_template(
self.processor, self.config, prompt, num_images=1
)
output = self._generate_fn(
self.model,
self.processor,
formatted_prompt,
image,
max_tokens=max_tokens,
temp=temperature,
verbose=False
)
        return output.text if hasattr(output, "text") else str(output)
def get_model(config: ServerConfig) -> BaseModel:
"""Model factory with explicit backend selection (no fallback)."""
try:
return MLXModel(config.model_path, config.resolved_backend)
except Exception as e:
logger.error(f"MLX load failed: {e}")
raise RuntimeError(
f"MLX model load failed: {e}\n\n"
f"This usually means the model at '{config.model_path}' is not an MLX-compatible model."
) from e
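
# Minimal smoke test; a sketch, not part of the server. Assumes ServerConfig
# can be constructed with `model_path` and `backend` keyword arguments
# (hypothetical here; only `model_path` and `resolved_backend` are read above).
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)
    cfg = ServerConfig(model_path=sys.argv[1], backend=sys.argv[2])
    model = get_model(cfg)
    image = Image.new("RGB", (336, 336), "white")
    print(model.generate(image, "Describe this image.", max_tokens=32))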