# RedFish / app.py
# Uploaded by hongyu12321 ("Upload 3 files"), commit df3ab33 (1.14 kB)
# app.py
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import gradio as gr
from PIL import Image
from hf_model import PretrainedAgeEstimator
# Single module-level estimator: constructed once at import time and shared by
# every predict() call below (construction presumably loads the model weights
# — confirm in hf_model.PretrainedAgeEstimator).
est = PretrainedAgeEstimator()
def predict(img):
    """Estimate the age of the face in ``img`` for the Gradio interface.

    Args:
        img: The uploaded image. The input component is configured with
            ``type="pil"``, but a numpy array is also accepted and converted
            with ``Image.fromarray``. ``None`` (nothing uploaded) is handled
            explicitly instead of crashing.

    Returns:
        A 2-tuple ``(probs, summary)`` matching the interface's outputs:
          * ``probs`` - ``dict[str, float]`` mapping class label to
            probability for the top-5 classes (the form ``gr.Label`` takes),
            or ``None`` when no image was given (leaves the Label empty).
          * ``summary`` - Markdown string with the point estimate, or a
            prompt asking the user to upload an image.
    """
    # An unset image input arrives as None; without this guard
    # Image.fromarray(None) raises and the user only sees a generic error.
    if img is None:
        return None, "Please upload a face image."

    # Gradio may pass PIL or numpy; normalise to PIL for the estimator.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)

    age, top = est.predict(img, topk=5)

    # float() turns any tensor/numpy scalars from the estimator into plain
    # Python floats (assumption: est.predict may return such scalars).
    probs = {lbl: float(prob) for lbl, prob in top}
    summary = f"Estimated age: **{float(age):.1f}** years"
    return probs, summary
# Wire `predict` into a one-input, two-output Gradio UI: the uploaded image
# goes in; a top-5 probability Label and a Markdown summary line come out,
# in the same order as predict's return tuple.
demo = gr.Interface(
    predict,
    gr.Image(type="pil", label="Upload a face image"),
    [
        gr.Label(num_top_classes=5, label="Age Prediction (probabilities)"),
        gr.Markdown(label="Summary"),
    ],
    title="Pretrained Age Estimator",
    description=(
        "Runs a pretrained ViT-based age classifier and reports "
        "a point estimate from class probabilities."
    ),
)
if __name__ == "__main__":
demo.launch(share=True)