import os

import cv2
import numpy as np
import torch
from insightface.app import FaceAnalysis
from insightface.utils import face_align
from PIL import Image
from torchvision import models, transforms

from curricularface import get_model
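
# Matrix square root via symmetric eigendecomposition; eigenvalues are clamped at
# zero for numerical stability. Note that torch.linalg.eigh assumes a symmetric
# input, so applying it to sigma1 @ sigma2 in calculate_fid is an approximation of
# the exact matrix square root in the standard FID formula.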
def matrix_sqrt(matrix):
    eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
    sqrt_eigenvalues = torch.sqrt(torch.clamp(eigenvalues, min=0))
    sqrt_matrix = (eigenvectors * sqrt_eigenvalues).mm(eigenvectors.T)
    return sqrt_matrix
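
# Uniformly sample num_frames frames (BGR, as returned by OpenCV) across the video.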
def sample_video_frames(video_path, num_frames=16):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    frames = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
    cap.release()
    return frames
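
# Run InsightFace detection and return the face with the largest bounding-box
# area, or None if no face is found.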
def get_face_keypoints(face_model, image_bgr):
    face_info = face_model.get(image_bgr)
    if len(face_info) > 0:
        return sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
    return None

def load_image(image):
    img = image.convert('RGB')
    img = transforms.Resize((299, 299))(img)  # Resize to Inception input size
    img = transforms.ToTensor()(img)
    return img.unsqueeze(0)  # Add batch dimension
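
# Frechet Inception Distance between two sets of InceptionV3 activations:
# FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2)).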
| def calculate_fid(real_activations, fake_activations, device="cuda"): | |
| real_activations_tensor = torch.tensor(real_activations).to(device) | |
| fake_activations_tensor = torch.tensor(fake_activations).to(device) | |
| mu1 = real_activations_tensor.mean(dim=0) | |
| sigma1 = torch.cov(real_activations_tensor.T) | |
| mu2 = fake_activations_tensor.mean(dim=0) | |
| sigma2 = torch.cov(fake_activations_tensor.T) | |
| ssdiff = torch.sum((mu1 - mu2) ** 2) | |
| covmean = matrix_sqrt(sigma1.mm(sigma2)) | |
| if torch.is_complex(covmean): | |
| covmean = covmean.real | |
| fid = ssdiff + torch.trace(sigma1 + sigma2 - 2 * covmean) | |
| return fid.item() | |
| def batch_cosine_similarity(embedding_image, embedding_frames, device="cuda"): | |
| embedding_image = torch.tensor(embedding_image).to(device) | |
| embedding_frames = torch.tensor(embedding_frames).to(device) | |
| return torch.nn.functional.cosine_similarity(embedding_image, embedding_frames, dim=-1).cpu().numpy() | |
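
# Extract feature activations from the FID model in batches. If only a single
# activation is produced, it is duplicated so the covariance estimation in
# calculate_fid has more than one sample.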
def get_activations(images, model, batch_size=16):
    model.eval()
    activations = []
    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size]
            pred = model(batch)
            activations.append(pred)
    activations = torch.cat(activations, dim=0).cpu().numpy()
    if activations.shape[0] == 1:
        activations = np.repeat(activations, 2, axis=0)
    return activations
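
# Pad a BGR image with a gray border on all sides so faces near the image boundary
# can still be detected; also return the (left, top) offset so keypoints can be
# mapped back to the original coordinates.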
def pad_np_bgr_image(np_image, scale=1.25):
    assert scale >= 1.0, "scale should be >= 1.0"
    pad_scale = scale - 1.0
    h, w = np_image.shape[:2]
    top = bottom = int(h * pad_scale)
    left = right = int(w * pad_scale)
    return cv2.copyMakeBorder(np_image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128)), (left, top)
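
# Detect the largest face in an RGB image (file path or array), retrying on a
# padded copy if detection fails, and return the aligned 224x224 RGB face crop
# together with the ArcFace embedding from the detector.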
def process_image(face_model, image_path):
    if isinstance(image_path, str):
        np_faceid_image = np.array(Image.open(image_path).convert("RGB"))
    elif isinstance(image_path, np.ndarray):
        np_faceid_image = image_path
    else:
        raise TypeError("image_path should be a string or a numpy.ndarray")
    image_bgr = cv2.cvtColor(np_faceid_image, cv2.COLOR_RGB2BGR)
    face_info = get_face_keypoints(face_model, image_bgr)
    if face_info is None:
        padded_image, sub_coord = pad_np_bgr_image(image_bgr)
        face_info = get_face_keypoints(face_model, padded_image)
        if face_info is None:
            print("Warning: No face detected in the image. Continuing processing...")
            return None, None
        face_kps = face_info['kps']
        face_kps -= np.array(sub_coord)
    else:
        face_kps = face_info['kps']
    arcface_embedding = face_info['embedding']
    norm_face = face_align.norm_crop(image_bgr, landmark=face_kps, image_size=224)
    align_face = cv2.cvtColor(norm_face, cv2.COLOR_BGR2RGB)
    return align_face, arcface_embedding
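
# Compute an L2-normalized CurricularFace embedding for an aligned face crop:
# resize to 112x112, convert to CHW, and normalize pixel values to [-1, 1].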
def inference(face_model, img, device):
    img = cv2.resize(img, (112, 112))
    img = np.transpose(img, (2, 0, 1))
    img = torch.from_numpy(img).unsqueeze(0).float().to(device)
    img.div_(255).sub_(0.5).div_(0.5)
    embedding = face_model(img).detach().cpu().numpy()[0]
    return embedding / np.linalg.norm(embedding)
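
# For each sampled frame of the video, compute CurricularFace and ArcFace cosine
# similarities against the reference-image embeddings plus a per-frame FID, and
# return the averages over all frames in which a face was detected.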
def process_video(video_path, face_arc_model, face_cur_model, fid_model, arcface_image_embedding, cur_image_embedding, real_activations, device):
    video_frames = sample_video_frames(video_path, num_frames=16)

    # Initialize lists to store the per-frame scores
    cur_scores = []
    arc_scores = []
    fid_face = []

    for frame in video_frames:
        # Convert to RGB once at the beginning
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Detect and align the face, and get its ArcFace embedding
        align_face_frame, arcface_frame_embedding = process_image(face_arc_model, frame_rgb)

        # Skip the frame if alignment fails
        if align_face_frame is None:
            continue

        # CurricularFace embedding of the aligned frame
        cur_embedding_frame = inference(face_cur_model, align_face_frame, device)

        # Cosine similarities against the reference image, clamped at 0
        cur_score = max(0.0, batch_cosine_similarity(cur_image_embedding, cur_embedding_frame, device=device).item())
        arc_score = max(0.0, batch_cosine_similarity(arcface_image_embedding, arcface_frame_embedding, device=device).item())

        # FID between the reference-face activations and this frame's face
        align_face_frame_pil = Image.fromarray(align_face_frame)
        fake_image = load_image(align_face_frame_pil).to(device)
        fake_activations = get_activations(fake_image, fid_model)
        fid_score = calculate_fid(real_activations, fake_activations, device)

        # Collect scores
        fid_face.append(fid_score)
        cur_scores.append(cur_score)
        arc_scores.append(arc_score)

    # Aggregate results with default values for empty lists
    avg_cur_score = np.mean(cur_scores) if cur_scores else 0.0
    avg_arc_score = np.mean(arc_scores) if arc_scores else 0.0
    avg_fid_score = np.mean(fid_face) if fid_face else 0.0

    return avg_cur_score, avg_arc_score, avg_fid_score
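
# Pair each generated video in data_path with its reference image (matched by
# filename), evaluate every pair, and report the averaged CurricularFace, ArcFace,
# and FID scores.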
def main():
    device = "cuda"

    # data_path = "data/SkyActor"
    # data_path = "data/LivePotraits"
    # data_path = "data/Actor-One"
    data_path = "data/FollowYourEmoji"
    img_path = "/maindata/data/shared/public/rui.wang/act_review/ref_images"
    pre_tag = False

    # Collect video files and their matching reference images
    mp4_list = os.listdir(data_path)
    print(mp4_list)
    img_list = []
    video_list = []
    for mp4 in mp4_list:
        if "mp4" not in mp4:
            continue
        if pre_tag:
            png_path = mp4.split('.')[0].split('-')[0] + ".png"
        else:
            if "-" in mp4:
                png_path = mp4.split('.')[0].split('-')[1] + ".png"
            else:
                png_path = mp4.split('.')[0].split('_')[1] + ".png"
        img_list.append(os.path.join(img_path, png_path))
        video_list.append(os.path.join(data_path, mp4))
    print(img_list)
    print(video_list[0])

    model_path = "eval"
    face_arc_path = os.path.join(model_path, "face_encoder")
    face_cur_path = os.path.join(face_arc_path, "glint360k_curricular_face_r101_backbone.bin")

    # Initialize the InsightFace model for face detection and ArcFace embedding extraction
    face_arc_model = FaceAnalysis(root=face_arc_path, providers=['CUDAExecutionProvider'])
    face_arc_model.prepare(ctx_id=0, det_size=(320, 320))

    # Load the CurricularFace recognition model
    face_cur_model = get_model('IR_101')([112, 112])
    face_cur_model.load_state_dict(torch.load(face_cur_path, map_location="cpu"))
    face_cur_model = face_cur_model.to(device)
    face_cur_model.eval()

    # Load InceptionV3 model for FID calculation
    fid_model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
    fid_model.fc = torch.nn.Identity()  # Remove final classification layer
    fid_model.eval()
    fid_model = fid_model.to(device)

    # Process each image/video pair: extract embeddings and activations from the
    # reference image, then score the corresponding video
    cur_list, arc_list, fid_list = [], [], []
    for i in range(len(img_list)):
        align_face_image, arcface_image_embedding = process_image(face_arc_model, img_list[i])
        cur_image_embedding = inference(face_cur_model, align_face_image, device)
        align_face_image_pil = Image.fromarray(align_face_image)
        real_image = load_image(align_face_image_pil).to(device)
        real_activations = get_activations(real_image, fid_model)

        cur_score, arc_score, fid_score = process_video(
            video_list[i], face_arc_model, face_cur_model, fid_model,
            arcface_image_embedding, cur_image_embedding, real_activations, device
        )
        print(cur_score, arc_score, fid_score)
        cur_list.append(cur_score)
        arc_list.append(arc_score)
        fid_list.append(fid_score)

    print("cur", sum(cur_list) / len(cur_list))
    print("arc", sum(arc_list) / len(arc_list))
    print("fid", sum(fid_list) / len(fid_list))


if __name__ == "__main__":
    main()