# Project Files
# src/fastvlm_server/florence2_model.py
"""Florence-2 object detection model (lazy-loaded, mlx-community port)."""
from __future__ import annotations

import base64
import io
import re
from typing import Optional

from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
_processor: Optional[object] = None
_model: Optional[object] = None
_model_path: Optional[str] = None
def _patch_florence2_source(model_path: str) -> None:
"""Patch Florence-2 remote code files in the model directory.
from_pretrained(trust_remote_code=True, local path) always re-copies model
Python files from the local directory to the HF modules cache, overwriting
any in-place cache edits. Patching the source files is the only reliable fix.
After patching, the HF modules cache for Florence-2 is deleted so that
from_pretrained re-copies the now-correct sources on the next call.
"""
import pathlib, shutil
patches = [
# configuration_florence2.py: forced_bos_token_id accessed before parent __init__
(
"configuration_florence2.py",
"if self.forced_bos_token_id is None",
"if getattr(self, 'forced_bos_token_id', None) is None",
),
# processing_florence2.py: additional_special_tokens not accessible on RobertaTokenizer
(
"processing_florence2.py",
"tokenizer.additional_special_tokens + \\",
"getattr(tokenizer, 'additional_special_tokens', []) + \\",
),
# modeling_florence2.py: _supports_sdpa accessed before language_model is initialized
(
"modeling_florence2.py",
"SDPA or not.\n \"\"\"\n return self.language_model._supports_sdpa",
"SDPA or not.\n \"\"\"\n if not hasattr(self, 'language_model'):\n return False\n return self.language_model._supports_sdpa",
),
]
patched_any = False
base = pathlib.Path(model_path)
for filename, old, new in patches:
src = base / filename
if not src.exists():
continue
text = src.read_text(encoding="utf-8")
if old not in text:
continue # already patched
src.write_text(text.replace(old, new), encoding="utf-8")
patched_any = True
if not patched_any:
return
# Remove the stale HF cache so from_pretrained re-copies the patched sources.
cache_base = (
pathlib.Path.home()
/ ".cache" / "huggingface" / "modules" / "transformers_modules"
)
for d in cache_base.glob("*lorence*2*"):
try:
shutil.rmtree(str(d))
except OSError:
pass
def load(model_path: str) -> None:
    """Load the Florence-2 processor and model for *model_path* (idempotent).

    A repeated call with the same path reuses the already-loaded model.
    """
    global _processor, _model, _model_path
    already_loaded = _model is not None and _model_path == model_path
    if already_loaded:
        return
    # Repair the repo's remote-code files before transformers copies them.
    _patch_florence2_source(model_path)
    _processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    _model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    _model_path = model_path
def detect(image_b64: str, task: str = "<OD>") -> dict:
    """Run Florence-2 detection on a base64-encoded image.

    Args:
        image_b64: Base64-encoded image bytes (any format PIL can open).
        task: "<OD>" for open-vocabulary detection,
              "<OPEN_VOCABULARY_DETECTION>text" for filtered detection,
              "<CAPTION_TO_PHRASE_GROUNDING>sentence" for grounded captions.

    Returns:
        {"bboxes": [[x1,y1,x2,y2], ...], "labels": [...], "width": W, "height": H}

    Raises:
        RuntimeError: if load() has not completed successfully first.
    """
    # Fail fast with a clear message instead of an opaque AttributeError on
    # calling a None processor/model.
    if _processor is None or _model is None:
        raise RuntimeError("Florence-2 model not loaded; call load(model_path) first")
    img_data = base64.b64decode(image_b64)
    image = Image.open(io.BytesIO(img_data)).convert("RGB")
    W, H = image.size
    # Returned for every failure path below; fields mirror the success shape.
    empty = {"bboxes": [], "labels": [], "width": W, "height": H}

    inputs = _processor(text=task, images=image, return_tensors="pt")
    generated_ids = _model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        use_cache=False,
    )
    generated_text = _processor.batch_decode(
        generated_ids, skip_special_tokens=False
    )[0]
    parsed = _processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(W, H),
    )
    # post_process_generation may return a plain string for tasks it cannot
    # post-process (e.g. an unrecognised task token).
    if not isinstance(parsed, dict):
        return empty
    # The result is keyed by the task token prefix only (e.g.
    # "<OPEN_VOCABULARY_DETECTION>"), not by the full task string that may
    # include a text query (e.g. "<OPEN_VOCABULARY_DETECTION>lamp").
    raw = parsed.get(task)
    if raw is None:
        m = re.match(r"(<[^>]+>)", task)
        token = m.group(1) if m else task
        raw = parsed.get(token, {})
    if not isinstance(raw, dict):
        return empty
    # <OD> uses "labels"; <OPEN_VOCABULARY_DETECTION> uses "bboxes_labels".
    labels = raw.get("labels") or raw.get("bboxes_labels", [])
    return {
        "bboxes": raw.get("bboxes", []),
        "labels": labels,
        "width": W,
        "height": H,
    }