# app.py — Compact UI: Age-first + FAST cartoon (Turbo) with collapsible advanced options
"""Gradio app: estimate a person's age from a photo, then make a fast SD-Turbo cartoon.

All models (age classifier, MTCNN face detector, SD-Turbo img2img pipeline with
a safety checker) are loaded once at import time and shared by the UI handlers.
"""
import os

# Environment flags are set before the ML libraries are imported below so they
# are visible when those libraries initialise.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from typing import Optional  # noqa: E402

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from PIL import Image, ImageDraw  # noqa: E402

# ------------------ Age estimator (Hugging Face) ------------------
from transformers import AutoImageProcessor, AutoModelForImageClassification  # noqa: E402

HF_MODEL_ID = "nateraw/vit-age-classifier"

# Midpoint (in years) used for each age-range label when computing the expected age.
AGE_RANGE_TO_MID = {
    "0-2": 1, "3-9": 6, "10-19": 15, "20-29": 25, "30-39": 35,
    "40-49": 45, "50-59": 55, "60-69": 65, "70+": 75,
    # Alias: the checkpoint's id2label may spell the oldest bucket "more than 70"
    # (FairFace convention). Without it that bucket would fall back to 35.
    # TODO confirm against model.config.id2label.
    "more than 70": 75,
}


class PretrainedAgeEstimator:
    """Age-range classifier backed by a Hugging Face image-classification checkpoint."""

    def __init__(self, model_id: str = HF_MODEL_ID, device: Optional[str] = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
        self.model = AutoModelForImageClassification.from_pretrained(model_id)
        self.model.to(self.device).eval()
        self.id2label = self.model.config.id2label

    @torch.inference_mode()
    def predict(self, img: Image.Image, topk: int = 5):
        """Classify *img* into age ranges.

        Returns:
            ``(expected, top)``.
            ``expected`` is the probability-weighted mean of the bucket midpoints
            in ``AGE_RANGE_TO_MID``; labels missing from that map count as 35.
            ``top`` is a list of up to ``topk`` ``(label, probability)`` pairs,
            most likely first.
        """
        if img.mode != "RGB":
            img = img.convert("RGB")
        inputs = self.processor(images=img, return_tensors="pt").to(self.device)
        logits = self.model(**inputs).logits
        # Copy to host once instead of forcing a device sync per element below.
        probs = logits.softmax(dim=-1).squeeze(0).float().cpu()
        k = min(topk, probs.numel())
        values, indices = torch.topk(probs, k=k)
        top = [(self.id2label[i.item()], float(v.item())) for i, v in zip(indices, values)]
        expected = sum(
            AGE_RANGE_TO_MID.get(self.id2label[i], 35) * float(p)
            for i, p in enumerate(probs)
        )
        return expected, top


# ------------------ Largest-face detector with nice margin ------------------
from facenet_pytorch import MTCNN  # noqa: E402


class FaceCropper:
    """Detect faces and return ``(wide_crop, annotated)``.

    Only the largest face is cropped. The crop is padded with a margin so the
    face does not fill the whole frame.
    """

    def __init__(self, device: Optional[str] = None, margin_scale: float = 1.85):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.mtcnn = MTCNN(keep_all=True, device=self.device)
        # Crop width = margin_scale * the larger side of the detected face box.
        self.margin_scale = margin_scale

    def _ensure_pil(self, img):
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        return Image.fromarray(img).convert("RGB")

    def detect_and_crop_wide(self, img):
        """Return ``(crop, annotated)``.

        ``crop`` is ``None`` when no face is found. ``annotated`` is a copy of
        the input with every detected box and its confidence drawn in red.
        """
        pil = self._ensure_pil(img)
        W, H = pil.size
        boxes, probs = self.mtcnn.detect(pil)

        annotated = pil.copy()
        draw = ImageDraw.Draw(annotated)
        if boxes is None or len(boxes) == 0:
            return None, annotated

        # draw all boxes
        for b, p in zip(boxes, probs):
            bx1, by1, bx2, by2 = map(float, b)
            draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
            draw.text((bx1, max(0, by1 - 12)), f"{p:.2f}", fill=(255, 0, 0))

        # choose largest
        idx = int(np.argmax([(b[2] - b[0]) * (b[3] - b[1]) for b in boxes]))
        x1, y1, x2, y2 = boxes[idx]

        # Expand around the face centre. The target is a 4:5 portrait
        # (height = 1.25 * width), then clamped to the image bounds.
        cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
        w, h = (x2 - x1), (y2 - y1)
        side = max(w, h) * self.margin_scale
        target_w = side
        target_h = side * 1.25
        nx1 = int(max(0, cx - target_w / 2))
        nx2 = int(min(W, cx + target_w / 2))
        ny1 = int(max(0, cy - target_h / 2))
        ny2 = int(min(H, cy + target_h / 2))

        crop = pil.crop((nx1, ny1, nx2, ny2))
        return crop, annotated


# ------------------ Fast Cartoonizer (SD-Turbo) with safety ------------------
from diffusers import AutoPipelineForImage2Image  # noqa: E402
from diffusers.pipelines.stable_diffusion.safety_checker import (  # noqa: E402
    StableDiffusionSafetyChecker,
)
from transformers import AutoFeatureExtractor  # noqa: E402

TURBO_ID = "stabilityai/sd-turbo"
SAFETY_CHECKER_ID = "CompVis/stable-diffusion-safety-checker"


def load_turbo_pipe(device):
    """Load the SD-Turbo img2img pipeline on *device* with the NSFW safety checker enabled."""
    # Use fp16 only when the target device is a GPU; fp16 on CPU is slow or unsupported.
    dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
    pipe = AutoPipelineForImage2Image.from_pretrained(
        TURBO_ID,
        dtype=dtype,  # NOTE(review): older diffusers releases expect `torch_dtype` — confirm for the pinned version
    ).to(device)

    # Safety checker ON for public Spaces. It is attached after `pipe.to(device)`,
    # so it must be moved explicitly. Otherwise it stays on CPU/float32 while the
    # pipeline hands it inputs on the pipeline's device and dtype.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(SAFETY_CHECKER_ID)
    pipe.safety_checker = safety_checker.to(device=pipe.device, dtype=pipe.dtype)
    pipe.feature_extractor = AutoFeatureExtractor.from_pretrained(SAFETY_CHECKER_ID)

    try:
        pipe.enable_attention_slicing()
    except Exception:
        pass  # best effort: memory optimisation only, generation works without it
    return pipe


# ------------------ Init models once ------------------
age_est = PretrainedAgeEstimator()
cropper = FaceCropper(device=age_est.device, margin_scale=1.85)
sd_pipe = load_turbo_pipe(age_est.device)

# ------------------ Hint choices (with defaults) ------------------
ROLE_CHOICES = [
    "Queen/Princess", "King/Prince", "Fairy", "Elf", "Knight",
    "Sorcerer/Sorceress", "Steampunk Royalty", "Cyberpunk Royalty",
    "Superhero", "Anime Protagonist",
]
BACKGROUND_CHOICES = [
    "grand castle hall", "castle balcony at sunset", "enchanted forest",
    "starry night sky", "throne room with banners", "crystal palace",
    "moonlit garden", "winter snow castle", "golden hour meadow", "mystical waterfall",
]
LIGHTING_CHOICES = [
    "soft magical lighting", "golden hour rim light", "cinematic soft light",
    "glowing ambience", "volumetric light rays", "dramatic chiaroscuro",
]
ARTSTYLE_CHOICES = [
    "Disney/Pixar style", "Studio Ghibli watercolor", "cel-shaded cartoon",
    "storybook illustration", "painterly brush strokes", "anime lineart",
]
COLOR_CHOICES = [
    "pastel palette", "vibrant colors", "warm tones", "cool tones",
    "iridescent highlights", "royal gold & sapphire",
]
OUTFIT_CHOICES = [
    "elegant gown", "ornate royal cloak", "jeweled tiara/crown",
    "silver diadem", "flowing cape", "intricate embroidery",
]
EFFECTS_CHOICES = [
    "sparkles", "soft bokeh background", "floating petals",
    "glowing particles", "butterflies", "magical aura",
]
NEGATIVE_PROMPT = (
    "deformed, disfigured, ugly, extra limbs, extra fingers, bad anatomy, low quality, blurry, watermark, text, logo"
)

# Descriptive prompt phrase for each role in ROLE_CHOICES.
_ROLE_PROMPTS = {
    "Queen/Princess": "regal queen/princess portrait",
    "King/Prince": "regal king/prince portrait",
    "Fairy": "ethereal fairy portrait with delicate wings",
    "Elf": "elven royalty portrait with elegant ears",
    "Knight": "valiant knight portrait in ornate armor",
    "Sorcerer/Sorceress": "mystical sorcerer portrait with arcane motifs",
    "Steampunk Royalty": "steampunk royal portrait with brass filigree",
    "Cyberpunk Royalty": "cyberpunk royal portrait with neon accents",
    "Superhero": "heroic comic-style portrait",
    "Anime Protagonist": "anime protagonist portrait",
}


# ------------------ Helpers ------------------
def _ensure_pil(img):
    """Return *img* as a PIL image, converting from a NumPy array if needed."""
    return img if isinstance(img, Image.Image) else Image.fromarray(img)


def _resize_512(im: Image.Image, max_side: int = 512):
    """Downscale *im* so its longer side is at most *max_side*.

    Aspect ratio is kept. Smaller images are returned unchanged.
    """
    w, h = im.size
    scale = max_side / max(w, h)
    if scale < 1.0:
        # max(1, ...) keeps extremely thin images from collapsing to 0 px.
        im = im.resize((max(1, int(w * scale)), max(1, int(h * scale))), Image.LANCZOS)
    return im


def build_prompt(role, background, lighting, artstyle, colors, outfit, effects, extra):
    """Assemble the SD prompt from the UI selections.

    Defaults always exist; user selections override them. ``role`` is a single
    choice, mapped to a descriptive phrase when it is a known role. The other
    groups are lists of phrases joined with ", ". ``extra`` is free text
    appended last.
    """
    # Defaults (applied if user doesn't choose)
    role = role or "Queen/Princess"
    background = background or ["castle balcony at sunset"]
    lighting = lighting or ["soft magical lighting"]
    artstyle = artstyle or ["storybook illustration"]
    colors = colors or ["vibrant colors"]
    outfit = outfit or ["elegant gown", "jeweled tiara/crown"]
    effects = effects or ["sparkles", "glowing particles"]

    parts = [_ROLE_PROMPTS.get(role, role)]
    for group in (background, lighting, artstyle, colors, outfit, effects):
        if group and isinstance(group, list):
            parts.append(", ".join(group))
    parts.append("clean lineart, high quality")
    extra = (extra or "").strip()
    if extra:
        parts.append(extra)
    return ", ".join(p for p in parts if p)


def _effective_steps(steps, strength: float) -> int:
    """Return the smallest step count >= *steps* that leaves img2img at least one denoising step.

    Diffusers img2img pipelines only denoise for ``int(steps * strength)``
    steps. The SD-Turbo docs require ``num_inference_steps * strength >= 1``,
    so e.g. steps=1 with strength=0.5 would otherwise run zero steps and fail.
    """
    steps = max(1, int(steps))
    if strength <= 0:
        return steps  # invalid strength: let the pipeline report it
    while int(steps * strength) < 1:
        steps += 1
    return steps


# ------------------ Actions ------------------
@torch.inference_mode()
def predict_age_only(img, auto_crop=True):
    """Gradio handler: return ``(top-5 label->prob dict, markdown summary, preview image)``."""
    if img is None:
        return {}, "Please upload an image.", None
    pil = _ensure_pil(img).convert("RGB")
    face_wide, annotated = (None, None)
    if auto_crop:
        face_wide, annotated = cropper.detect_and_crop_wide(pil)
    # Fall back to the full image when cropping is off or no face was found.
    target = face_wide if face_wide is not None else pil
    age, top = age_est.predict(target, topk=5)
    probs = {lbl: float(p) for lbl, p in top}
    summary = f"**Estimated age:** {age:.1f} years"
    return probs, summary, (annotated if annotated is not None else pil)


@torch.inference_mode()
def generate_cartoon(img, role, background, lighting, artstyle, colors, outfit, effects,
                     extra_desc, auto_crop=True, strength=0.5, steps=2, seed=-1):
    """Gradio handler: return the cartoonised image, or ``None`` when no image was given.

    ``steps`` is treated as a minimum. It is raised when needed so that
    ``steps * strength`` yields at least one denoising step. A negative
    ``seed`` means random.
    """
    if img is None:
        return None
    pil = _ensure_pil(img).convert("RGB")
    if auto_crop:
        face_wide, _ = cropper.detect_and_crop_wide(pil)
        if face_wide is not None:
            pil = face_wide
    pil = _resize_512(pil)

    prompt = build_prompt(role, background, lighting, artstyle, colors, outfit, effects, extra_desc)

    generator = None
    if isinstance(seed, (int, float)) and int(seed) >= 0:
        generator = torch.Generator(device=age_est.device).manual_seed(int(seed))

    strength = float(strength)
    out = sd_pipe(
        prompt=prompt,
        # NOTE(review): with guidance_scale=0.0 diffusers skips classifier-free
        # guidance, so the negative prompt presumably has no effect here.
        negative_prompt=NEGATIVE_PROMPT,
        image=pil,
        strength=strength,  # 0.4–0.6 keeps identity & adds dress/background
        guidance_scale=0.0,  # Turbo likes 0
        num_inference_steps=_effective_steps(steps, strength),  # 1–4 → fast
        generator=generator,
    )
    return out.images[0]


# ------------------ Compact UI ------------------
with gr.Blocks(title="Age + Cartoon (Compact)") as demo:
    gr.Markdown("## Upload → Predict Age → Make Cartoon ✨")
    gr.Markdown("Largest face is used if multiple people are present. Defaults are applied automatically.")

    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
            auto = gr.Checkbox(True, label="Auto face crop (recommended)")

            # Buttons visible immediately (no scrolling)
            with gr.Row():
                btn_age = gr.Button("Predict Age", variant="primary")
                btn_cartoon = gr.Button("Make Cartoon", variant="secondary")

            # Collapsible advanced options
            with gr.Accordion("🎨 Advanced Cartoon Options", open=False):
                role = gr.Dropdown(choices=ROLE_CHOICES, value="Queen/Princess", label="Role")
                background = gr.CheckboxGroup(choices=BACKGROUND_CHOICES, value=["castle balcony at sunset"], label="Background")
                lighting = gr.CheckboxGroup(choices=LIGHTING_CHOICES, value=["soft magical lighting"], label="Lighting")
                artstyle = gr.CheckboxGroup(choices=ARTSTYLE_CHOICES, value=["storybook illustration"], label="Art Style")
                colors = gr.CheckboxGroup(choices=COLOR_CHOICES, value=["vibrant colors"], label="Color Mood")
                outfit = gr.CheckboxGroup(choices=OUTFIT_CHOICES, value=["elegant gown", "jeweled tiara/crown"], label="Outfit / Accessories")
                effects = gr.CheckboxGroup(choices=EFFECTS_CHOICES, value=["sparkles", "glowing particles"], label="Magical Effects")
                extra = gr.Textbox(label="Extra description (optional)", placeholder="e.g., silver tiara, flowing gown, balcony at sunset")
                with gr.Row():
                    strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
                    # generate_cartoon may raise this so steps * strength >= 1.
                    steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
                    seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")

        with gr.Column(scale=1):
            probs_out = gr.Label(num_top_classes=5, label="Age Prediction")
            age_md = gr.Markdown(label="Age Summary")
            preview = gr.Image(label="Detection Preview")
            cartoon_out = gr.Image(label="Cartoon Result")

    # Wire events
    btn_age.click(fn=predict_age_only, inputs=[img_in, auto], outputs=[probs_out, age_md, preview])
    btn_cartoon.click(
        fn=generate_cartoon,
        inputs=[img_in, role, background, lighting, artstyle, colors, outfit, effects, extra, auto, strength, steps, seed],
        outputs=cartoon_out,
    )

# Expose for HF Spaces
app = demo

if __name__ == "__main__":
    app.queue().launch()