# florence-2 / app.py
# Florence-2 vision demo (FastAPI): OCR, captioning, and object detection on CPU.
import io

import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import HTMLResponse
from PIL import Image
from transformers import AutoProcessor, Florence2ForConditionalGeneration  # <--- DIRECT IMPORT
app = FastAPI()

print("⏳ Initializing Florence-2 (Hardcoded Class Mode)...")

# The community fork ships a clean config, so no remote code is needed.
model_id = "florence-community/Florence-2-large"
device = "cpu"

try:
    # 1. Processor handles both tokenization and image preprocessing.
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=False)
    # 2. Instantiate via the concrete class — no "AutoModel" dispatch guessing.
    model = Florence2ForConditionalGeneration.from_pretrained(
        model_id,
        trust_remote_code=False,
        torch_dtype=torch.float32,
    ).to(device)
    print("βœ… Model Loaded Successfully!")
except Exception as e:
    # Keep the app serving even if the model is unavailable; endpoints
    # check for None and report the failure instead of crashing at import.
    print(f"❌ Load Error: {e}")
    model = None
    processor = None
# --- UI ---
# Single-page frontend served at "/": the user uploads an image, picks a
# Florence-2 task token (<OCR>, <CAPTION>, or <OD>), and the page POSTs the
# file plus `task_prompt` as multipart form data to /analyze, then renders
# the JSON response. The markup below is served verbatim at runtime.
html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Naman's AI Vision</title>
<style>
body { font-family: sans-serif; background: #0f172a; color: white; display: flex; flex-direction: column; align-items: center; min-height: 100vh; margin: 0; padding: 20px; }
.box { background: #1e293b; padding: 30px; border-radius: 15px; width: 100%; max-width: 600px; text-align: center; border: 1px solid #334155; }
h1 { margin-top: 0; color: #38bdf8; }
button { background: #38bdf8; color: #000; border: none; padding: 10px 20px; border-radius: 5px; font-weight: bold; cursor: pointer; margin-top: 10px; }
button:disabled { opacity: 0.5; }
#result { margin-top: 20px; white-space: pre-wrap; text-align: left; background: #000; padding: 15px; border-radius: 5px; font-family: monospace; display: none; }
img { max-width: 100%; border-radius: 10px; margin-top: 10px; display: none; }
</style>
</head>
<body>
<div class="box">
<h1>πŸ‘οΈ Florence-2 Vision AI</h1>
<p>Advanced OCR & Image Understanding (CPU)</p>
<input type="file" id="file" accept="image/*" style="display: none;">
<button onclick="document.getElementById('file').click()">πŸ“‚ Upload Image</button>
<br><br>
<select id="task" style="padding: 10px; border-radius: 5px;">
<option value="<OCR>">πŸ“„ Read Text (OCR)</option>
<option value="<CAPTION>">πŸ–ΌοΈ Describe Image</option>
<option value="<OD>">πŸ“¦ Detect Objects</option>
</select>
<button onclick="runAI()" id="runBtn">Run AI</button>
<img id="preview">
<div id="result"></div>
</div>
<script>
const fileInput = document.getElementById('file');
const preview = document.getElementById('preview');
const result = document.getElementById('result');
const runBtn = document.getElementById('runBtn');
let currentFile = null;
fileInput.addEventListener('change', (e) => {
currentFile = e.target.files[0];
preview.src = URL.createObjectURL(currentFile);
preview.style.display = 'block';
result.style.display = 'none';
});
async function runAI() {
if (!currentFile) return alert("Select an image first!");
runBtn.innerText = "Processing...";
runBtn.disabled = true;
result.style.display = 'none';
const formData = new FormData();
formData.append('file', currentFile);
formData.append('task_prompt', document.getElementById('task').value);
try {
const res = await fetch('/analyze', { method: 'POST', body: formData });
const data = await res.json();
result.innerText = data.result || "Error: " + JSON.stringify(data);
result.style.display = 'block';
} catch (e) {
alert("Error: " + e);
}
runBtn.innerText = "Run AI";
runBtn.disabled = false;
}
</script>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page upload UI."""
    return html_content
@app.post("/analyze")
async def analyze(task_prompt: str = Form("<OCR>"), file: UploadFile = File(...)):
    """Run a Florence-2 task over an uploaded image.

    Args:
        task_prompt: Florence-2 task token ("<OCR>", "<CAPTION>", "<OD>").
            Declared with ``Form(...)`` because the frontend sends it as
            multipart form data — a bare ``str`` default would make FastAPI
            read it from the query string, silently ignoring the user's
            selection and always running the "<OCR>" default.
        file: Uploaded image in any PIL-readable format.

    Returns:
        ``{"result": str}`` on success, ``{"error": str}`` on failure.
    """
    if not model:
        return {"error": "Model failed to load"}
    try:
        img = Image.open(io.BytesIO(await file.read())).convert("RGB")
        inputs = processor(text=task_prompt, images=img, return_tensors="pt").to(device)
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=False,
        )
        # Keep special tokens: post_process_generation parses the task token
        # and location tags out of the raw decoded string.
        text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed = processor.post_process_generation(
            text, task=task_prompt, image_size=img.size
        )
        return {"result": str(parsed)}
    except Exception as e:
        # Boundary handler: surface the failure to the client as JSON
        # rather than a 500 with no body.
        return {"error": str(e)}