|
|
import torch |
|
|
import torchvision |
|
|
import torchvision.transforms as transforms |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import gradio as gr |
|
|
import sys |
|
|
import os |
|
|
import tqdm |
|
|
sys.path.append(os.path.abspath(os.path.join("", ".."))) |
|
|
import torch |
|
|
import gc |
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
from PIL import Image |
|
|
from utils import load_models, save_model_w2w, save_model_for_diffusers |
|
|
from editing import get_direction, debias |
|
|
from sampling import sample_weights |
|
|
from lora_w2w import LoRAw2w |
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
# Shared runtime state. Several functions below rebind these names via
# `global`. `network` is not created here: sample_model() and invert()
# bind it.
device = "cuda:0"
generator = torch.Generator(device=device)

# Download (or reuse the cached copy of) the released weights2weights files.
models_path = snapshot_download(repo_id="Snapchat/w2w")


def _load_w2w_file(filename):
    """Load ``<models_path>/files/<filename>`` with ``torch.load``."""
    return torch.load(f"{models_path}/files/{filename}")


def _load_w2w_tensor(filename):
    """Load a tensor artifact, cast it to bfloat16 and move it to ``device``."""
    return _load_w2w_file(filename).bfloat16().to(device)


mean = _load_w2w_tensor("mean.pt")
std = _load_w2w_tensor("std.pt")
v = _load_w2w_tensor("V.pt")
proj = _load_w2w_tensor("proj_1000pc.pt")
# These two artifacts are used as loaded (no dtype/device conversion).
df = _load_w2w_file("identity_df.pt")
weight_dimensions = _load_w2w_file("weight_dimensions.pt")
pinverse = _load_w2w_tensor("pinverse_1000pc.pt")

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
|
|
|
|
|
|
|
def sample_model():
    """Load a fresh UNet and sample a new identity network for it.

    Rebinds the module-level ``unet`` and ``network``. The sample is drawn by
    ``sample_weights`` from the first 1000 columns of ``v`` with
    ``factor=1.00``.
    """
    global unet, network

    # Release the module-level reference to the current UNet before loading
    # its replacement; the other components returned by load_models are
    # discarded and the already-loaded ones keep being used.
    del unet
    unet, _vae, _text_encoder, _tokenizer, _scheduler = load_models(device)

    network = sample_weights(
        unet, proj, mean, std, v[:, :1000], device, factor=1.00
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    """Sample one image from the UNet with the current identity ``network``.

    Runs classifier-free-guided denoising over ``noise_scheduler.timesteps``
    starting from seeded Gaussian latents, then decodes with the VAE.

    Args:
        prompt: positive text prompt (the UI asks for it to contain
            "sks person").
        negative_prompt: text for the unconditional branch of guidance.
            Truncated to the tokenizer's maximum length.
        guidance_scale: weight ``w`` in ``uncond + w * (text - uncond)``.
        ddim_steps: number of steps passed to
            ``noise_scheduler.set_timesteps``.
        seed: seed applied to the module-level ``generator`` before drawing
            the initial latents.

    Returns:
        A ``PIL.Image.Image``.
    """
    global generator

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    # truncation=True: without it a negative prompt longer than max_length
    # tokens produces more ids than the prompt embeddings have, and the two
    # embeddings could not be concatenated below.
    uncond_input = tokenizer(
        [negative_prompt],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    # Batch order is [unconditional, conditional]; chunk(2) below relies on it.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for t in tqdm.tqdm(noise_scheduler.timesteps):
        # Duplicate the latents so both guidance branches run in one pass.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        with network:
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                timestep_cond=None,
            ).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the latent scaling (invert() multiplies encoded latents by 0.18215).
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    # Map decoder output from [-1, 1] to [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    return image
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
    """Sample an image from the current ``network`` with attribute edits applied.

    ``network.proj`` is shifted along the module-level ``young``, ``pointy``,
    ``wavy`` and ``large`` directions by ``a1``..``a4``, scaled by 1e6, 1e6,
    1e6 and 2e6 respectively. The edited weights are switched in only once
    the scheduler timestep ``t`` is <= ``start_noise``. Steps before that use
    the unedited weights. The original weights are restored before
    returning, including when sampling raises.

    Args:
        prompt: positive text prompt.
        negative_prompt: text for the unconditional branch. Truncated to the
            tokenizer's maximum length.
        guidance_scale: weight ``w`` in ``uncond + w * (text - uncond)``.
        ddim_steps: number of steps passed to
            ``noise_scheduler.set_timesteps``.
        seed: seed for the module-level ``generator``.
        start_noise: timestep threshold at/below which the edit is applied.
        a1, a2, a3, a4: edit strengths for young / pointy / wavy / large.

    Returns:
        A ``PIL.Image.Image``.
    """
    global generator

    original_weights = network.proj.clone()
    # NOTE(review): the directions come from get_direction(..., 1000, ...);
    # this presumably requires network.proj to have a matching width (true
    # for sampled models, not for an inverted model with pcs != 1000) -
    # confirm.
    edited_weights = (
        original_weights
        + a1 * 1e6 * young
        + a2 * 1e6 * pointy
        + a3 * 1e6 * wavy
        + a4 * 2e6 * large
    )

    try:
        generator = generator.manual_seed(seed)
        latents = torch.randn(
            (1, unet.in_channels, 512 // 8, 512 // 8),
            generator=generator,
            device=device,
        ).bfloat16()

        text_input = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

        max_length = text_input.input_ids.shape[-1]
        # truncation=True keeps an over-long negative prompt the same length as
        # the prompt embeddings so they can be concatenated below.
        uncond_input = tokenizer(
            [negative_prompt],
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
        # Batch order is [unconditional, conditional]; chunk(2) relies on it.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        noise_scheduler.set_timesteps(ddim_steps)
        latents = latents * noise_scheduler.init_noise_sigma

        for t in tqdm.tqdm(noise_scheduler.timesteps):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

            # Switch to the edited identity once t reaches the injection step.
            if t <= start_noise:
                network.proj = torch.nn.Parameter(edited_weights)
                network.reset()

            with network:
                noise_pred = unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    timestep_cond=None,
                ).sample

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        # Undo the latent scaling and map decoder output from [-1, 1] to [0, 1].
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).round().astype("uint8"))
    finally:
        # Always put the unedited identity back so later calls are unaffected,
        # even if sampling failed part-way.
        network.proj = torch.nn.Parameter(original_weights)
        network.reset()

    return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_then_run():
    """Sample a fresh identity model and render it with the default settings.

    Returns:
        The PIL image produced by ``inference`` for "sks person" with the
        default negative prompt, CFG 3.0, 50 steps and seed 5.
    """
    sample_model()
    return inference(
        "sks person",
        "low quality, blurry, unfinished, cartoon",
        3.0,
        50,
        5,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _debiased_direction(attribute, confounders):
    """Return the ``attribute`` edit direction, debiased against each confounder in order."""
    direction = get_direction(df, attribute, pinverse, 1000, device)
    for confounder in confounders:
        direction = debias(direction, confounder, df, pinverse, device)
    return direction


def _projection_range(direction):
    """Return ``(max, min)`` of ``proj @ direction[0] / ||direction||**2`` as Python floats."""
    scores = proj @ direction[0] / torch.norm(direction) ** 2
    return torch.max(scores).item(), torch.min(scores).item()


# Attribute directions used by edit_inference (a1..a4). The debias order is
# kept exactly as listed for each direction.
young = _debiased_direction("Young", ["Male", "Pointy_Nose", "Wavy_Hair", "Chubby"])
young_max, young_min = _projection_range(young)

pointy = _debiased_direction("Pointy_Nose", ["Young", "Male", "Wavy_Hair", "Chubby", "Heavy_Makeup"])
pointy_max, pointy_min = _projection_range(pointy)

wavy = _debiased_direction("Wavy_Hair", ["Young", "Male", "Pointy_Nose", "Chubby", "Heavy_Makeup"])
wavy_max, wavy_min = _projection_range(wavy)

large = _debiased_direction("Chubby", ["Male", "Young", "Pointy_Nose", "Wavy_Hair"])
large_max, large_min = _projection_range(large)
|
|
|
|
|
class CustomImageDataset(Dataset):
    """Map-style dataset over an in-memory, indexable sequence of samples.

    Args:
        images: indexable sequence of samples (``invert`` passes a list of
            PIL images).
        transform: optional callable applied to a sample each time it is
            fetched; skipped when falsy.
    """

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        return self.transform(sample) if self.transform else sample
|
|
|
|
|
def invert(image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
    """Fit a new identity network to ``image`` by training a LoRAw2w model.

    A freshly loaded UNet is wrapped in a ``LoRAw2w`` network. The network is
    built from a zero-initialised (1, pcs) coefficient tensor and the first
    ``pcs`` columns of ``v``. Its parameters are trained with Adam on the
    masked noise-prediction MSE loss for the fixed prompt "sks person".

    Args:
        image: indexable sequence of training images (``run_inversion``
            passes a one-element list of 512x512 RGB PIL images).
        mask: PIL image. It is resized to 64x64 and its first channel is
            used as a per-pixel loss weight. An all-zero mask falls back to
            uniform weighting.
        pcs: number of principal components of ``v`` to use.
        epochs: number of passes over ``image``.
        weight_decay: Adam weight decay.
        lr: Adam learning rate.

    Returns:
        The trained network. It is also bound to the module-level
        ``network``, and the freshly loaded UNet to ``unet``.
    """
    global unet
    global network

    # Replace the module-level UNet with a freshly loaded copy.
    del unet
    unet, _, _, _, _ = load_models(device)

    initial_proj = torch.zeros(1, pcs).bfloat16().to(device)
    network = LoRAw2w(
        initial_proj,
        mean,
        std,
        v[:, :pcs],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    # Downsample the mask to the 64x64 latent grid and keep one channel,
    # giving a (1, 1, 64, 64) weight.
    # NOTE(review): pil_to_tensor yields 0..255 values while the fallback
    # below is all ones, so a drawn mask weights the loss up to 255x more
    # than the fallback (this interacts with weight_decay) - confirm intent.
    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
    mask = (
        torchvision.transforms.functional.pil_to_tensor(mask)
        .unsqueeze(0)
        .to(device)
        .bfloat16()[:, 0, :, :]
        .unsqueeze(1)
    )
    if torch.sum(mask) == 0:
        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()

    image_transforms = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_dataset = CustomImageDataset(image, transform=image_transforms)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    # The prompt never changes and the text encoder is not being optimised,
    # so encode it once, outside the loop and without an autograd graph.
    with torch.no_grad():
        text_input = tokenizer(
            "sks person",
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    unet.train()
    try:
        for _ in tqdm.tqdm(range(epochs)):
            for batch in train_dataloader:
                batch = batch.to(device).bfloat16()
                # The VAE only produces training targets; it is not optimised.
                with torch.no_grad():
                    latents = vae.encode(batch).latent_dist.sample() * 0.18215
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                ).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                with network:
                    model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
                loss = torch.nn.functional.mse_loss(
                    mask * model_pred.float(), mask * noise.float(), reduction="mean"
                )

                optim.zero_grad()
                loss.backward()
                optim.step()
    finally:
        # The same UNet is used for sampling afterwards (run_inversion calls
        # inference), so do not leave it in training mode.
        unet.eval()

    return network
|
|
|
|
|
|
|
|
|
|
|
def run_inversion(dict, pcs, epochs, weight_decay, lr):
    """Invert an uploaded image, preview the resulting model and save its weights.

    Args:
        dict: value of the sketch-enabled ``gr.Image`` input. It is a
            mapping with "image" (the upload) and "mask" (the drawn mask) PIL
            images, or None when nothing has been uploaded. The name shadows
            the builtin but is kept for compatibility.
        pcs: number of principal components, forwarded to ``invert``.
        epochs: number of training epochs, forwarded to ``invert``.
        weight_decay: Adam weight decay, forwarded to ``invert``.
        lr: Adam learning rate, forwarded to ``invert``.

    Returns:
        ``(preview_image, "model.pt")``: an image sampled from the inverted
        model and the path where ``network.proj`` was saved.

    Raises:
        gr.Error: when no image has been uploaded.
    """
    global network

    if dict is None or dict.get("image") is None:
        raise gr.Error("Please upload an image before running the inversion.")

    init_image = dict["image"].convert("RGB").resize((512, 512))
    mask = dict["mask"].convert("RGB").resize((512, 512))

    # UI widgets may hand over whole numbers as floats; torch.zeros and
    # range() inside invert() need ints.
    network = invert([init_image], mask, int(pcs), int(epochs), float(weight_decay), float(lr))

    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, cartoon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)

    torch.save(network.proj, "model.pt")
    return image, "model.pt"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# HTML header rendered at the top of the demo via gr.HTML(intro): the title
# and subtitle, links to the project page and paper, and a
# "Duplicate Space" badge linking to the Hugging Face Space.
intro = """
<div style="display: flex;align-items: center;justify-content: center">
<h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
<h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models</h3>
</div>
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
<a href="https://snap-research.github.io/weights2weights/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">paper</a>
|
<a href="https://huggingface.co/spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="
display: inline-block;
">
<img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
</p>
"""
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)

    # Tab 1: sample a random identity model, then edit it along the
    # attribute directions (a1..a4 feed edit_inference).
    with gr.Tab("Sampling Models + Editing"):
        with gr.Row():
            with gr.Column():
                gallery1 = gr.Image(label="Identity from Sampled Model")
                sample = gr.Button("Sample New Model")
            gallery2 = gr.Image(label="Identity from Edited Model")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt",
                                    info="Make sure to include 'sks person'",
                                    placeholder="sks person",
                                    value="sks person")
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
                with gr.Row():
                    a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)

                    a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                with gr.Row():
                    a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    # NOTE(review): placeholder label; this slider scales the
                    # `large` direction (built from the "Chubby" attribute).
                    a4 = gr.Slider(label="- placeholder for some fourth attribute +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)

                with gr.Accordion("Advanced Options", open=False):
                    with gr.Column():
                        seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
                        cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                        # minimum=1: zero steps is not a valid schedule
                        # (set_timesteps(0) divides by zero for DDIM).
                        steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=1, maximum=100, interactive=True)
                        # Timestep at/below which edit_inference applies the edit.
                        injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)

                submit = gr.Button("Generate")

        sample.click(fn=sample_then_run, outputs=gallery1)

        submit.click(fn=edit_inference,
                     inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
                     outputs=gallery2)

    # Tab 2: invert an uploaded (optionally masked) image into a model,
    # then sample from it and offer its weights for download.
    with gr.Tab("Inversion"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
                                       height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)

                lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
                weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
                pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
                epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)

                invert_button = gr.Button("Invert")

            with gr.Column():
                gallery = gr.Image(label="Sample from Inverted Model", height=512, width=512)
                # These rebind the tab-1 names; tab-1 listeners were already
                # attached above, so they keep their own components.
                prompt = gr.Textbox(label="Prompt",
                                    info="Make sure to include 'sks person'",
                                    placeholder="sks person",
                                    value="sks person")
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
                seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
                cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                # minimum=1 for the same reason as in the first tab.
                steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=1, maximum=100, interactive=True)
                submit = gr.Button("Generate")

                file_output = gr.File(label="Download Model", container=False)

        invert_button.click(fn=run_inversion,
                            inputs=[input_image, pcs, epochs, weight_decay, lr],
                            outputs=[gallery, file_output])

        submit.click(fn=inference,
                     inputs=[prompt, negative_prompt, cfg, steps, seed],
                     outputs=gallery)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Start the app with Gradio's request queue enabled; share=True asks Gradio
# to also create a public share link.
demo.queue().launch(share=True)
|
|
|
|
|
|