Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import sys
|
| 3 |
import os
|
|
@@ -9,8 +13,9 @@ import warnings
|
|
| 9 |
warnings.filterwarnings("ignore")
|
| 10 |
from PIL import Image
|
| 11 |
from utils import load_models, save_model_w2w, save_model_for_diffusers
|
| 12 |
-
from sampling import sample_weights
|
| 13 |
from editing import get_direction, debias
|
|
|
|
|
|
|
| 14 |
from huggingface_hub import snapshot_download
|
| 15 |
|
| 16 |
global device
|
|
@@ -20,11 +25,13 @@ global vae
|
|
| 20 |
global text_encoder
|
| 21 |
global tokenizer
|
| 22 |
global noise_scheduler
|
| 23 |
-
|
| 24 |
device = "cuda:0"
|
| 25 |
generator = torch.Generator(device=device)
|
| 26 |
|
| 27 |
|
|
|
|
|
|
|
| 28 |
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
| 29 |
|
| 30 |
mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
|
|
@@ -36,7 +43,7 @@ weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
|
|
| 36 |
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)
|
| 37 |
|
| 38 |
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
| 39 |
-
|
| 40 |
|
| 41 |
def sample_model():
|
| 42 |
global unet
|
|
@@ -47,6 +54,9 @@ def sample_model():
|
|
| 47 |
network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
|
| 48 |
|
| 49 |
|
|
|
|
|
|
|
|
|
|
| 50 |
@torch.no_grad()
|
| 51 |
def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
|
| 52 |
global device
|
|
@@ -94,7 +104,7 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
|
|
| 94 |
|
| 95 |
image = Image.fromarray((image * 255).round().astype("uint8"))
|
| 96 |
|
| 97 |
-
return
|
| 98 |
|
| 99 |
|
| 100 |
|
|
@@ -173,16 +183,13 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
|
|
| 173 |
network.proj = torch.nn.Parameter(original_weights)
|
| 174 |
network.reset()
|
| 175 |
|
| 176 |
-
return
|
| 177 |
|
| 178 |
|
| 179 |
|
| 180 |
|
| 181 |
def sample_then_run():
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
sample_model()
|
| 185 |
-
|
| 186 |
prompt = "sks person"
|
| 187 |
negative_prompt = "low quality, blurry, unfinished, cartoon"
|
| 188 |
seed = 5
|
|
@@ -192,6 +199,8 @@ def sample_then_run():
|
|
| 192 |
return image
|
| 193 |
|
| 194 |
|
|
|
|
|
|
|
| 195 |
#directions
|
| 196 |
global young
|
| 197 |
global pointy
|
|
@@ -233,6 +242,115 @@ large = debias(large, "Wavy_Hair", df, pinverse, device)
|
|
| 233 |
large_max = torch.max(proj@large[0]/(torch.norm(large))**2).item()
|
| 234 |
large_min = torch.min(proj@large[0]/(torch.norm(large))**2).item()
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
intro = """
|
| 238 |
<div style="display: flex;align-items: center;justify-content: center">
|
|
@@ -249,61 +367,97 @@ intro = """
|
|
| 249 |
</p>
|
| 250 |
"""
|
| 251 |
|
| 252 |
-
with gr.Blocks(css="style.css") as demo:
|
| 253 |
-
gr.HTML(intro)
|
| 254 |
-
with gr.Row():
|
| 255 |
-
with gr.Column():
|
| 256 |
-
gallery1 = gr.Gallery(label="Identity from Sampled Model")
|
| 257 |
-
sample = gr.Button("Sample New Model")
|
| 258 |
-
gallery2 = gr.Gallery(label="Identity from Edited Model")
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
with gr.Row():
|
| 262 |
-
with gr.Column():
|
| 263 |
-
prompt = gr.Textbox(label="Prompt",
|
| 264 |
-
info="Make sure to include 'sks person'" ,
|
| 265 |
-
placeholder="sks person",
|
| 266 |
-
value="sks person")
|
| 267 |
-
negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
|
| 268 |
-
with gr.Row():
|
| 269 |
-
a1 = gr.Slider(label="+Young", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
| 270 |
-
a2 = gr.Slider(label="+Pointy Nose", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
| 271 |
-
with gr.Row():
|
| 272 |
-
a3 = gr.Slider(label="+Curly Hair", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
| 273 |
-
a4 = gr.Slider(label="+Large", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
with gr.Accordion("Advanced Options", open=False):
|
| 277 |
-
with gr.Column():
|
| 278 |
-
seed = gr.Number(value=5, label="Seed", interactive=True)
|
| 279 |
-
cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
|
| 280 |
-
steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
|
| 281 |
-
injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
|
| 282 |
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
|
| 285 |
-
submit = gr.Button("Submit")
|
| 286 |
-
|
| 287 |
-
|
| 288 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
| 294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
| 299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
|
|
|
|
|
|
| 302 |
|
|
|
|
|
|
|
|
|
|
| 303 |
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
-
|
| 306 |
|
| 307 |
|
|
|
|
|
|
|
| 308 |
|
|
|
|
| 309 |
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
import gradio as gr
|
| 6 |
import sys
|
| 7 |
import os
|
|
|
|
| 13 |
warnings.filterwarnings("ignore")
|
| 14 |
from PIL import Image
|
| 15 |
from utils import load_models, save_model_w2w, save_model_for_diffusers
|
|
|
|
| 16 |
from editing import get_direction, debias
|
| 17 |
+
from sampling import sample_weights
|
| 18 |
+
from lora_w2w import LoRAw2w
|
| 19 |
from huggingface_hub import snapshot_download
|
| 20 |
|
| 21 |
global device
|
|
|
|
| 25 |
global text_encoder
|
| 26 |
global tokenizer
|
| 27 |
global noise_scheduler
|
| 28 |
+
global network
|
| 29 |
device = "cuda:0"
|
| 30 |
generator = torch.Generator(device=device)
|
| 31 |
|
| 32 |
|
| 33 |
+
|
| 34 |
+
|
| 35 |
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
| 36 |
|
| 37 |
mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
|
|
|
|
| 43 |
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)
|
| 44 |
|
| 45 |
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
| 46 |
+
|
| 47 |
|
| 48 |
def sample_model():
|
| 49 |
global unet
|
|
|
|
| 54 |
network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
|
| 55 |
|
| 56 |
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
@torch.no_grad()
|
| 61 |
def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
|
| 62 |
global device
|
|
|
|
| 104 |
|
| 105 |
image = Image.fromarray((image * 255).round().astype("uint8"))
|
| 106 |
|
| 107 |
+
return image
|
| 108 |
|
| 109 |
|
| 110 |
|
|
|
|
| 183 |
network.proj = torch.nn.Parameter(original_weights)
|
| 184 |
network.reset()
|
| 185 |
|
| 186 |
+
return image
|
| 187 |
|
| 188 |
|
| 189 |
|
| 190 |
|
| 191 |
def sample_then_run():
|
| 192 |
+
sample_model()
|
|
|
|
|
|
|
|
|
|
| 193 |
prompt = "sks person"
|
| 194 |
negative_prompt = "low quality, blurry, unfinished, cartoon"
|
| 195 |
seed = 5
|
|
|
|
| 199 |
return image
|
| 200 |
|
| 201 |
|
| 202 |
+
|
| 203 |
+
|
| 204 |
#directions
|
| 205 |
global young
|
| 206 |
global pointy
|
|
|
|
| 242 |
large_max = torch.max(proj@large[0]/(torch.norm(large))**2).item()
|
| 243 |
large_min = torch.min(proj@large[0]/(torch.norm(large))**2).item()
|
| 244 |
|
| 245 |
+
class CustomImageDataset(Dataset):
|
| 246 |
+
def __init__(self, images, transform=None):
|
| 247 |
+
self.images = images
|
| 248 |
+
self.transform = transform
|
| 249 |
+
|
| 250 |
+
def __len__(self):
|
| 251 |
+
return len(self.images)
|
| 252 |
+
|
| 253 |
+
def __getitem__(self, idx):
|
| 254 |
+
image = self.images[idx]
|
| 255 |
+
if self.transform:
|
| 256 |
+
image = self.transform(image)
|
| 257 |
+
return image
|
| 258 |
+
|
| 259 |
+
def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
|
| 260 |
+
global unet
|
| 261 |
+
del unet
|
| 262 |
+
global network
|
| 263 |
+
unet, _, _, _, _ = load_models(device)
|
| 264 |
+
|
| 265 |
+
proj = torch.zeros(1,pcs).bfloat16().to(device)
|
| 266 |
+
network = LoRAw2w( proj, mean, std, v[:, :pcs],
|
| 267 |
+
unet,
|
| 268 |
+
rank=1,
|
| 269 |
+
multiplier=1.0,
|
| 270 |
+
alpha=27.0,
|
| 271 |
+
train_method="xattn-strict"
|
| 272 |
+
).to(device, torch.bfloat16)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
### load mask
|
| 279 |
+
mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
|
| 280 |
+
mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:,0,:,:].unsqueeze(1)
|
| 281 |
+
### check if an actual mask was draw, otherwise mask is just all ones
|
| 282 |
+
if torch.sum(mask) == 0:
|
| 283 |
+
mask = torch.ones((1,1,64,64)).to(device).bfloat16()
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
### single image dataset
|
| 287 |
+
image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
|
| 288 |
+
transforms.RandomCrop(512),
|
| 289 |
+
transforms.ToTensor(),
|
| 290 |
+
transforms.Normalize([0.5], [0.5])])
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
train_dataset = CustomImageDataset(image, transform=image_transforms)
|
| 294 |
+
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
### optimizer
|
| 299 |
+
optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)
|
| 300 |
+
|
| 301 |
+
### training loop
|
| 302 |
+
unet.train()
|
| 303 |
+
for epoch in tqdm.tqdm(range(epochs)):
|
| 304 |
+
for batch in train_dataloader:
|
| 305 |
+
### prepare inputs
|
| 306 |
+
batch = batch.to(device).bfloat16()
|
| 307 |
+
latents = vae.encode(batch).latent_dist.sample()
|
| 308 |
+
latents = latents*0.18215
|
| 309 |
+
noise = torch.randn_like(latents)
|
| 310 |
+
bsz = latents.shape[0]
|
| 311 |
+
|
| 312 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
| 313 |
+
timesteps = timesteps.long()
|
| 314 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 315 |
+
text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
| 316 |
+
text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
|
| 317 |
+
|
| 318 |
+
### loss + sgd step
|
| 319 |
+
with network:
|
| 320 |
+
model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
|
| 321 |
+
loss = torch.nn.functional.mse_loss(mask*model_pred.float(), mask*noise.float(), reduction="mean")
|
| 322 |
+
optim.zero_grad()
|
| 323 |
+
loss.backward()
|
| 324 |
+
optim.step()
|
| 325 |
+
|
| 326 |
+
### return optimized network
|
| 327 |
+
|
| 328 |
+
return network
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def run_inversion(dict, pcs, epochs, weight_decay,lr):
|
| 333 |
+
global network
|
| 334 |
+
init_image = dict["image"].convert("RGB").resize((512, 512))
|
| 335 |
+
mask = dict["mask"].convert("RGB").resize((512, 512))
|
| 336 |
+
network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
#sample an image
|
| 340 |
+
prompt = "sks person"
|
| 341 |
+
negative_prompt = "low quality, blurry, unfinished, cartoon"
|
| 342 |
+
seed = 5
|
| 343 |
+
cfg = 3.0
|
| 344 |
+
steps = 50
|
| 345 |
+
image = inference( prompt, negative_prompt, cfg, steps, seed)
|
| 346 |
+
torch.save(network.proj, "model.pt" )
|
| 347 |
+
return image, "model.pt"
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
|
| 354 |
|
| 355 |
intro = """
|
| 356 |
<div style="display: flex;align-items: center;justify-content: center">
|
|
|
|
| 367 |
</p>
|
| 368 |
"""
|
| 369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
|
| 371 |
|
| 372 |
+
with gr.Blocks(css="style.css") as demo:
|
| 373 |
+
gr.HTML(intro)
|
| 374 |
+
with gr.Tab("Sampling Models + Editing"):
|
| 375 |
+
with gr.Row():
|
| 376 |
+
with gr.Column():
|
| 377 |
+
gallery1 = gr.Image(label="Identity from Sampled Model")
|
| 378 |
+
sample = gr.Button("Sample New Model")
|
| 379 |
+
gallery2 = gr.Image(label="Identity from Edited Model")
|
| 380 |
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
+
with gr.Row():
|
| 383 |
+
with gr.Column():
|
| 384 |
+
prompt = gr.Textbox(label="Prompt",
|
| 385 |
+
info="Make sure to include 'sks person'" ,
|
| 386 |
+
placeholder="sks person",
|
| 387 |
+
value="sks person")
|
| 388 |
+
negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
|
| 389 |
+
with gr.Row():
|
| 390 |
+
a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
| 391 |
+
|
| 392 |
+
a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
| 393 |
+
with gr.Row():
|
| 394 |
+
a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
| 395 |
+
a4 = gr.Slider(label="- placeholder for some fourth attribute +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
| 396 |
|
| 397 |
+
|
| 398 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 399 |
+
with gr.Column():
|
| 400 |
+
seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
|
| 401 |
+
cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
|
| 402 |
+
steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
|
| 403 |
+
injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
|
| 404 |
|
| 405 |
|
| 406 |
+
|
| 407 |
+
submit = gr.Button("Generate")
|
| 408 |
+
|
| 409 |
+
sample.click(fn=sample_then_run, outputs=gallery1)
|
| 410 |
|
| 411 |
+
submit.click(fn=edit_inference,
|
| 412 |
+
inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
|
| 413 |
+
outputs=gallery2)
|
| 414 |
+
|
| 415 |
|
| 416 |
+
|
| 417 |
+
with gr.Tab("Inversion"):
|
| 418 |
+
with gr.Row():
|
| 419 |
+
with gr.Column():
|
| 420 |
+
input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
|
| 421 |
+
height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
|
| 422 |
|
| 423 |
+
|
| 424 |
+
lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
|
| 425 |
+
weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
|
| 426 |
+
pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
|
| 427 |
+
epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
|
| 428 |
|
| 429 |
+
invert_button = gr.Button("Invert")
|
| 430 |
+
|
| 431 |
+
with gr.Column():
|
| 432 |
+
gallery = gr.Image(label="Sample from Inverted Model", height=512, width=512)
|
| 433 |
+
prompt = gr.Textbox(label="Prompt",
|
| 434 |
+
info="Make sure to include 'sks person'" ,
|
| 435 |
+
placeholder="sks person",
|
| 436 |
+
value="sks person")
|
| 437 |
+
negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
|
| 438 |
+
seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
|
| 439 |
+
cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
|
| 440 |
+
steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
|
| 441 |
+
submit = gr.Button("Generate")
|
| 442 |
+
|
| 443 |
+
file_output = gr.File(label="Download Model", container=False)
|
| 444 |
|
| 445 |
+
|
| 446 |
+
|
| 447 |
|
| 448 |
+
invert_button.click(fn=run_inversion,
|
| 449 |
+
inputs=[input_image, pcs, epochs, weight_decay,lr],
|
| 450 |
+
outputs = [gallery, file_output])
|
| 451 |
|
| 452 |
+
submit.click(fn=inference,
|
| 453 |
+
inputs=[prompt, negative_prompt, cfg, steps, seed,],
|
| 454 |
+
outputs=gallery)
|
| 455 |
|
| 456 |
+
|
| 457 |
|
| 458 |
|
| 459 |
+
|
| 460 |
+
|
| 461 |
|
| 462 |
+
demo.queue().launch(share=True)
|
| 463 |
|