# Copyright (C) 2025, FaceLift Research Group
# https://github.com/weijielyu/FaceLift
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact: [email protected]
"""
FaceLift: Single Image 3D Face Reconstruction
Generates 3D head models from single images using multi-view diffusion and GS-LRM.
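Run with: python app.py (the Gradio interface launches on port 7860).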
"""
import json
from pathlib import Path
from datetime import datetime
import gradio as gr
import numpy as np
import torch
import yaml
from easydict import EasyDict as edict
from einops import rearrange
from PIL import Image
from huggingface_hub import snapshot_download
import spaces
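# The spaces package provides the @spaces.GPU decorator used below for ZeroGPU allocation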
# Install diff-gaussian-rasterization at runtime (requires GPU)
import subprocess
import sys
import os
try:
import diff_gaussian_rasterization
except ImportError:
print("Installing diff-gaussian-rasterization...")
# Set CUDA architecture for A100 (compute capability 8.0)
# This prevents the IndexError in _get_cuda_arch_flags
env = os.environ.copy()
env["TORCH_CUDA_ARCH_LIST"] = "8.0"
subprocess.check_call([
sys.executable, "-m", "pip", "install",
"git+https://github.com/graphdeco-inria/diff-gaussian-rasterization"
], env=env)
import diff_gaussian_rasterization
from gslrm.model.gaussians_renderer import render_turntable, imageseq2video
from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping
# Import convert function for PLY to SPLAT conversion
import convert as ply_to_splat
# HuggingFace repository configuration
HF_REPO_ID = "wlyu/OpenFaceLift"
def download_weights_from_hf() -> Path:
"""Download model weights from HuggingFace if not already present.
Returns:
Path to the downloaded repository
"""
workspace_dir = Path(__file__).parent
# Check if weights already exist locally
mvdiffusion_path = workspace_dir / "checkpoints/mvdiffusion/pipeckpts"
gslrm_path = workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt"
if mvdiffusion_path.exists() and gslrm_path.exists():
print("Using local model weights")
return workspace_dir
print(f"Downloading model weights from HuggingFace: {HF_REPO_ID}")
print("This may take a few minutes on first run...")
# Download to local directory
snapshot_download(
repo_id=HF_REPO_ID,
local_dir=str(workspace_dir / "checkpoints"),
local_dir_use_symlinks=False,
)
print("Model weights downloaded successfully!")
return workspace_dir
class FaceLiftPipeline:
"""Pipeline for FaceLift 3D head generation from single images."""
def __init__(self):
# Download weights from HuggingFace if needed
workspace_dir = download_weights_from_hf()
# Setup paths
self.output_dir = workspace_dir / "outputs"
self.examples_dir = workspace_dir / "examples"
self.output_dir.mkdir(exist_ok=True)
# Parameters
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.image_size = 512
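# Indices into the opencv_cameras.json frames that pair each generated view with its camera pose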
self.camera_indices = [2, 1, 0, 5, 4, 3]
# Load models (keep on CPU for ZeroGPU compatibility)
print("Loading models...")
self.mvdiffusion_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
str(workspace_dir / "checkpoints/mvdiffusion/pipeckpts"),
torch_dtype=torch.float16,
)
# Don't move to device or enable xformers here - will be done in GPU-decorated function
self._models_on_gpu = False
with open(workspace_dir / "configs/gslrm.yaml", "r") as f:
config = edict(yaml.safe_load(f))
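# Dynamically import and instantiate the GS-LRM model class named in the config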
module_name, class_name = config.model.class_name.rsplit(".", 1)
module = __import__(module_name, fromlist=[class_name])
ModelClass = getattr(module, class_name)
self.gs_lrm_model = ModelClass(config)
checkpoint = torch.load(
workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt",
map_location="cpu"
)
self.gs_lrm_model.load_state_dict(checkpoint["model"])
# Keep on CPU initially - will move to GPU in decorated function
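# Pre-computed prompt embeddings for six-view color generation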
self.color_prompt_embedding = torch.load(
workspace_dir / "mvdiffusion/fixed_prompt_embeds_6view/clr_embeds.pt",
map_location="cpu"
)
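# Fixed camera intrinsics and world-to-camera poses (OpenCV convention) for the six target views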
with open(workspace_dir / "utils_folder/opencv_cameras.json", 'r') as f:
self.cameras_data = json.load(f)["frames"]
print("Models loaded successfully!")
def _move_models_to_gpu(self):
"""Move models to GPU and enable optimizations. Called within @spaces.GPU context."""
if not self._models_on_gpu and torch.cuda.is_available():
print("Moving models to GPU...")
self.device = torch.device("cuda:0")
self.mvdiffusion_pipeline.to(self.device)
self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
self.gs_lrm_model.to(self.device)
self.color_prompt_embedding = self.color_prompt_embedding.to(self.device)
self._models_on_gpu = True
print("Models on GPU, xformers enabled!")
def _create_viewer_html(self, splat_path):
"""Create standalone HTML viewer for the gaussian splat."""
import base64
# Read the splat file and encode as base64
with open(splat_path, 'rb') as f:
splat_data = f.read()
splat_b64 = base64.b64encode(splat_data).decode('utf-8')
# Read the main.js content and modify it to use embedded data
with open(Path(__file__).parent / "main.js", 'r') as f:
js_content = f.read()
# Replace the URL fetching part with blob URL
js_content = js_content.replace(
'params.get("url") || "train.splat"',
'window.EMBEDDED_SPLAT_URL || "train.splat"'
)
html = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>3D Gaussian Splat Viewer - FaceLift</title>
<style>
body {{
margin: 0;
padding: 0;
overflow: hidden;
font-family: Arial, sans-serif;
background: #000;
}}
#canvas {{
width: 100vw;
height: 100vh;
display: block;
}}
#info {{
position: absolute;
top: 10px;
left: 10px;
background: rgba(0, 0, 0, 0.7);
color: white;
padding: 10px;
border-radius: 5px;
font-size: 12px;
max-width: 300px;
z-index: 1000;
}}
#info h3 {{
margin: 0 0 10px 0;
font-size: 14px;
}}
#loading {{
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
color: white;
font-size: 18px;
text-align: center;
z-index: 999;
}}
</style>
</head>
<body>
<canvas id="canvas"></canvas>
<div id="loading">Loading 3D model...</div>
<div id="info">
<h3>Controls</h3>
<p><b>Mouse:</b> Click and drag to orbit</p>
<p><b>Right click/Ctrl+drag:</b> Move forward/back (up/down), strafe (left/right)</p>
<p><b>Arrow keys:</b> Move forward/back, strafe left/right</p>
<p><b>WASD:</b> Rotate camera</p>
<p><b>Space:</b> Jump</p>
<p><b>Q/E:</b> Roll camera</p>
</div>
<script>
// Convert base64 to blob and create URL
const SPLAT_DATA_B64 = '{splat_b64}';
function b64toBlob(b64Data, contentType='application/octet-stream', sliceSize=512) {{
const byteCharacters = atob(b64Data);
const byteArrays = [];
for (let offset = 0; offset < byteCharacters.length; offset += sliceSize) {{
const slice = byteCharacters.slice(offset, offset + sliceSize);
const byteNumbers = new Array(slice.length);
for (let i = 0; i < slice.length; i++) {{
byteNumbers[i] = slice.charCodeAt(i);
}}
const byteArray = new Uint8Array(byteNumbers);
byteArrays.push(byteArray);
}}
return new Blob(byteArrays, {{type: contentType}});
}}
// Create blob URL for embedded splat data
const blob = b64toBlob(SPLAT_DATA_B64);
window.EMBEDDED_SPLAT_URL = URL.createObjectURL(blob);
// Remove loading message once render starts
setTimeout(() => {{
const loading = document.getElementById('loading');
if (loading) loading.style.display = 'none';
}}, 2000);
</script>
<script>
{js_content}
</script>
</body>
</html>"""
return html
@spaces.GPU
def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0,
random_seed=4, num_steps=50):
"""Generate 3D head from single image."""
# Move models to GPU now that we're in the GPU context
self._move_models_to_gpu()
try:
# Setup output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = self.output_dir / timestamp
output_dir.mkdir(exist_ok=True)
# Preprocess input
original_img = np.array(Image.open(image_path))
input_image = preprocess_image(original_img) if auto_crop else \
preprocess_image_without_cropping(original_img)
if input_image.size != (self.image_size, self.image_size):
input_image = input_image.resize((self.image_size, self.image_size))
input_path = output_dir / "input.png"
input_image.save(input_path)
# Generate multi-view images
generator = torch.Generator(device=self.mvdiffusion_pipeline.unet.device)
generator.manual_seed(int(random_seed))  # seed may arrive as a float from the UI
result = self.mvdiffusion_pipeline(
input_image, None,
prompt_embeds=self.color_prompt_embedding,
guidance_scale=guidance_scale,
num_images_per_prompt=1,
num_inference_steps=num_steps,
generator=generator,
eta=1.0,
)
selected_views = result.images[:6]
# Save multi-view composite
multiview_image = Image.new("RGB", (self.image_size * 6, self.image_size))
for i, view in enumerate(selected_views):
multiview_image.paste(view, (self.image_size * i, 0))
multiview_path = output_dir / "multiview.png"
multiview_image.save(multiview_path)
# Prepare 3D reconstruction input
view_arrays = [np.array(view) for view in selected_views]
lrm_input = torch.from_numpy(np.stack(view_arrays, axis=0)).float()
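# Add a batch dimension and normalize pixel values to [0, 1]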
lrm_input = lrm_input[None].to(self.device) / 255.0
lrm_input = rearrange(lrm_input, "b v h w c -> b v c h w")
# Prepare camera parameters
selected_cameras = [self.cameras_data[i] for i in self.camera_indices]
fxfycxcy_list = [[c["fx"], c["fy"], c["cx"], c["cy"]] for c in selected_cameras]
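# Convert stored world-to-camera matrices to camera-to-world poses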
c2w_list = [np.linalg.inv(np.array(c["w2c"])) for c in selected_cameras]
fxfycxcy = torch.from_numpy(np.stack(fxfycxcy_list, axis=0).astype(np.float32))
c2w = torch.from_numpy(np.stack(c2w_list, axis=0).astype(np.float32))
fxfycxcy = fxfycxcy[None].to(self.device)
c2w = c2w[None].to(self.device)
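# (batch_index, view_index) pairs identifying each input view; batch size is 1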
batch_indices = torch.stack([
torch.zeros(lrm_input.size(1)).long(),
torch.arange(lrm_input.size(1)).long(),
], dim=-1)[None].to(self.device)
batch = edict({
"image": lrm_input,
"c2w": c2w,
"fxfycxcy": fxfycxcy,
"index": batch_indices,
})
# Run 3D reconstruction
with torch.autocast(enabled=True, device_type="cuda", dtype=torch.float16):
result = self.gs_lrm_model.forward(batch, create_visual=False, split_data=True)
comp_image = result.render[0].unsqueeze(0).detach()
gaussians = result.gaussians[0]
# Prune low-opacity, oversized, and floater Gaussians, crop to a bounding box, and save as PLY
filtered_gaussians = gaussians.apply_all_filters(
cam_origins=None,
opacity_thres=0.04,
scaling_thres=0.2,
floater_thres=0.75,
crop_bbx=[-0.91, 0.91, -0.91, 0.91, -1.0, 1.0],
nearfar_percent=(0.0001, 1.0),
)
ply_path = output_dir / "gaussians.ply"
filtered_gaussians.save_ply(str(ply_path))
# Convert PLY to SPLAT format for web viewer
splat_path = output_dir / "gaussians.splat"
ply_to_splat.convert(str(ply_path), str(splat_path))
# Create HTML viewer
viewer_html = self._create_viewer_html(str(splat_path))
viewer_path = output_dir / "viewer.html"
with open(viewer_path, 'w') as f:
f.write(viewer_html)
# Save output image
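# Tile the per-view renders side by side into a single image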
comp_image = rearrange(comp_image, "x v c h w -> (x h) (v w) c")
comp_image = (comp_image.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
output_path = output_dir / "output.png"
Image.fromarray(comp_image).save(output_path)
# Generate turntable video
turntable_frames = render_turntable(gaussians, rendering_resolution=self.image_size,
num_views=180)
turntable_frames = rearrange(turntable_frames, "h (v w) c -> v h w c", v=180)
turntable_frames = np.ascontiguousarray(turntable_frames)
turntable_path = output_dir / "turntable.mp4"
imageseq2video(turntable_frames, str(turntable_path), fps=30)
return str(input_path), str(multiview_path), str(output_path), \
str(turntable_path), str(viewer_path), str(ply_path)
except Exception as e:
raise gr.Error(f"Generation failed: {e}") from e
def main():
"""Run the FaceLift application."""
pipeline = FaceLiftPipeline()
# Load examples
examples = []
if pipeline.examples_dir.exists():
examples = [[str(f)] for f in sorted(pipeline.examples_dir.iterdir())
if f.suffix.lower() in {'.png', '.jpg', '.jpeg'}]
# Create interface
demo = gr.Interface(
fn=pipeline.generate_3d_head,
title="FaceLift: Single Image 3D Face Reconstruction",
description="""
Transform a single portrait image into a complete 3D head model with an interactive WebGL viewer.
**Tips:**
- Use high-quality portrait images with clear facial features
- If face detection fails, try disabling auto-cropping and manually crop to square
- Download the Interactive 3D Viewer HTML file to explore your model in full screen
""",
inputs=[
gr.Image(type="filepath", label="Input Portrait Image"),
gr.Checkbox(value=True, label="Auto Cropping"),
gr.Slider(1.0, 10.0, 3.0, step=0.1, label="Guidance Scale"),
gr.Number(value=4, precision=0, label="Random Seed"),
gr.Slider(10, 100, 50, step=5, label="Generation Steps"),
],
outputs=[
gr.Image(label="Processed Input"),
gr.Image(label="Multi-view Generation"),
gr.Image(label="3D Reconstruction"),
gr.PlayableVideo(label="Turntable Animation"),
gr.File(label="Interactive 3D Viewer (.html) - Download & Open"),
gr.File(label="3D Model (.ply)"),
],
examples=examples,
allow_flagging="never",
)
demo.queue(max_size=10)
demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)
if __name__ == "__main__":
main()