RishubhPar committed on
Commit f7a4a99 · verified · 1 Parent(s): e90aefb

small changes

Files changed (1):
  1. app.py +109 -75
app.py CHANGED
@@ -3,6 +3,7 @@ import gc
 from typing import List, Tuple, Dict
 import json
 import spaces
+ import traceback
 
 import torch
 import gradio as gr
@@ -26,80 +27,103 @@ if HF_TOKEN:
 # -----------------------------
 # Avoid meta-tensor init from environment leftovers
 os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
-
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- print("Using device:", DEVICE)
- torch.backends.cudnn.benchmark = True
-
 PIPELINE=None
 
 # -----------------------------
 # Model / pipeline loading
 # -----------------------------
- @torch.no_grad()
- def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
-     global PIPELINE, DEVICE
-
-     pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
-
-     n_slider_layers = 4
-     slider_projector_out_dim = 6144
-     trained_models_path = "./model_weights/"
-     is_clip_input = True
-
-     # Load transformer fully on CPU; avoid meta tensors
-     transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
-         pretrained,
-         subfolder="transformer",
-         device_map=None,
-         low_cpu_mem_usage=False,
-         token=HF_TOKEN,
-     )
-     weight_dtype = transformer.dtype  # keep checkpoint dtype
-
-     # Slider projector
-     if is_clip_input:
-         slider_projector = SliderProjector(
-             out_dim=slider_projector_out_dim, pe_dim=2, n_layers=n_slider_layers, is_clip_input=True
-         )
-     else:
-         slider_projector = SliderProjector_wo_clip(
-             out_dim=slider_projector_out_dim, pe_dim=2, n_layers=n_slider_layers
-         )
-
-     # putting both the models to infer
-     transformer.eval()
-     slider_projector.eval()
-
-     # Load projector weights on CPU
-     slider_projector_path = os.path.join(trained_models_path, "slider_projector.pth")
-     state_dict = torch.load(slider_projector_path, map_location='cpu')
-     print("state_dict keys: {}".format(state_dict.keys()))
-
-     slider_projector.load_state_dict(state_dict)
-     print(f"loaded slider_projector from {slider_projector_path}")
-     # ------------------------------- --------------------- --------------------------- #
-
-     # Build full pipeline on CPU; no device_map sharding
-     PIPELINE = FluxKontextSliderPipeline.from_pretrained(
-         pretrained,
-         transformer=transformer,
-         slider_projector=slider_projector,
-         torch_dtype=weight_dtype,
-         device_map=None,
-         low_cpu_mem_usage=False,
-     )
-
-     print("loading the pipeline lora weights from: {}".format(trained_models_path))
-     PIPELINE.load_lora_weights(trained_models_path)
-     print("loaded the pipeline with lora weights from: {}".format(trained_models_path))
-     PIPELINE.to(DEVICE)
-
- # Initializing the pipeline with gpu
- print("INIT pipeline with the gpu")
- load_pipeline_single_gpu()
- print(f"[init] Pipeline loaded on {DEVICE}")
+ def _log(msg): print(msg, flush=True)
+
+ def load_pipeline_single_gpu():
+     global PIPELINE
+     if PIPELINE is not None:
+         _log("[worker] PIPELINE already initialized; skipping.")
+         return "warm"
+
+     try:
+         os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
+         token = os.environ.get("HF_TOKEN")
+         cuda_ok = torch.cuda.is_available()
+         _log(f"[worker] cuda available: {cuda_ok}")
+         if cuda_ok:
+             torch.backends.cudnn.benchmark = True
+
+         # ---------- config ----------
+         pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
+         trained_models_path = "./model_weights/"
+         projector_path = os.path.join(trained_models_path, "slider_projector.pth")
+         offload_dir = "/tmp/offload"; os.makedirs(offload_dir, exist_ok=True)
+
+         if not os.path.isdir(trained_models_path):
+             return f"error: missing dir {trained_models_path}"
+         if not os.path.isfile(projector_path):
+             return f"error: missing projector weights at {projector_path}"
+
+         # dtype selection to cut memory
+         if cuda_ok and torch.cuda.get_device_capability(0)[0] >= 8:
+             dtype = torch.bfloat16
+         elif cuda_ok:
+             dtype = torch.float16
+         else:
+             dtype = torch.float32
+
+         max_memory = {"cuda": "80GiB", "cpu": "60GiB"}  # tune if needed
+
+         _log("[worker] loading transformer (sharded/offloaded)…")
+         transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
+             pretrained,
+             subfolder="transformer",
+             token=token,
+             trust_remote_code=True,
+             torch_dtype=dtype,
+             low_cpu_mem_usage=True,
+             # device_map="balanced_low_0",
+             offload_folder=offload_dir,
+             offload_state_dict=True,
+             # max_memory=max_memory,
+         )
+         weight_dtype = transformer.dtype
+         _log(f"[worker] transformer loaded, dtype={weight_dtype}")
+
+         _log("[worker] building slider projector…")
+         slider_projector = SliderProjector(out_dim=6144, pe_dim=2, n_layers=4, is_clip_input=True)
+         slider_projector.eval()
+         _log("[worker] loading projector weights…")
+         state_dict = torch.load(projector_path, map_location="cpu", weights_only=True)
+         slider_projector.load_state_dict(state_dict, strict=True)
+
+         _log("[worker] assembling pipeline (sharded/offloaded)…")
+         pipe = FluxKontextSliderPipeline.from_pretrained(
+             pretrained,
+             token=token,
+             trust_remote_code=True,
+             transformer=transformer,
+             slider_projector=slider_projector,
+             torch_dtype=weight_dtype,
+             low_cpu_mem_usage=True,
+             # device_map="balanced_low_0",
+             offload_folder=offload_dir,
+             offload_state_dict=True,
+             # max_memory=max_memory,
+         )
+         _log("[worker] pipeline assembled.")
+
+         _log(f"[worker] loading LoRA from: {trained_models_path}")
+         pipe.load_lora_weights(trained_models_path)
+         _log("[worker] LoRA loaded.")
+
+         # DO NOT pipe.to("cuda") here; keep auto device_map to avoid OOM
+         PIPELINE = pipe
+         if cuda_ok:
+             free, total = torch.cuda.mem_get_info()
+             _log(f"[worker] VRAM free/total: {free/1e9:.2f}/{total/1e9:.2f} GB")
+         _log("[worker] PIPELINE ready.")
+         return "ok"
+
+     except Exception:
+         _log("[worker] init exception:\n" + traceback.format_exc())
+         return "error"
 
 # -----------------------------
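
The rewritten loader swaps eager import-time initialization (the old module-level `load_pipeline_single_gpu()` call plus `PIPELINE.to(DEVICE)`) for a lazy singleton: it returns early when already warm, picks a dtype from the GPU's compute capability, and reports failures as status strings instead of raising. A minimal sketch of that pattern, assuming nothing beyond stock PyTorch (`get_model` and the `nn.Linear` stand-in are illustrative, not part of the app):

    import torch

    _MODEL = None  # module-level singleton, like PIPELINE above

    def get_model() -> str:
        # Initialize once; later calls return "warm" immediately.
        global _MODEL
        if _MODEL is not None:
            return "warm"
        cuda_ok = torch.cuda.is_available()
        # bfloat16 needs compute capability >= 8 (Ampere or newer);
        # otherwise fall back to float16 on GPU or float32 on CPU.
        if cuda_ok and torch.cuda.get_device_capability(0)[0] >= 8:
            dtype = torch.bfloat16
        elif cuda_ok:
            dtype = torch.float16
        else:
            dtype = torch.float32
        try:
            _MODEL = torch.nn.Linear(4, 4).to(dtype=dtype)  # stand-in for the pipeline
            return "ok"
        except Exception:
            return "error"  # callers inspect the status string, as app.py does

Returning status strings instead of raising keeps the worker boundary simple: callers just log the result, which is exactly what the new inference entry points do.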
@@ -287,23 +311,25 @@ def resize_image(img: Image.Image, target: int = 512) -> Image.Image:
     img = img.resize((new_w, new_h), resample)
     return img
 
- @spaces.GPU
- def _encode_prompt(prompt: str):
-     with torch.no_grad():
-         pe, ppe, _ = PIPELINE.encode_prompt(prompt, prompt_2=prompt)
-     return pe, ppe
-
-
 # -----------------------------
 # Inference functions
 # -----------------------------
- @spaces.GPU(duration=300)
+ @spaces.GPU
 @torch.no_grad()
- def generate_image_stack_edits(text_prompt, n_edits, input_image, progress=gr.Progress(track_tqdm=True)):
+ def generate_image_stack_edits(text_prompt, n_edits, input_image):
     """
     Compute n_edits images on a single GPU for slider values in (0,1],
     return (list_of_images, first_image) so the UI shows immediately.
     """
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # if pipeline is null will initialize it simply.
+     global PIPELINE
+     if PIPELINE is None:
+         status = load_pipeline_single_gpu()
+         print("loaded pipeline status: {}".format(status))
+
     if not input_image or not text_prompt or text_prompt.startswith("Please select"):
         return [], None
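
Dropping the separate `_encode_prompt` helper and putting `@spaces.GPU` directly on the inference entry points means each Gradio call claims the ZeroGPU device once, rather than spinning up a second GPU task just to encode the prompt. For reference, a sketch of the decorator pattern on Hugging Face ZeroGPU Spaces (`double` is illustrative; `spaces` is the package app.py already imports):

    import spaces
    import torch

    @spaces.GPU  # the commit also drops the explicit duration=... budgets
    def double(x: torch.Tensor) -> torch.Tensor:
        # Inside a decorated function a CUDA device is attached, so
        # moving tensors to "cuda" is safe on ZeroGPU hardware.
        if torch.cuda.is_available():
            return (x.to("cuda") * 2).cpu()
        return x * 2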
@@ -312,7 +338,7 @@ def generate_image_stack_edits(text_prompt, n_edits, input_image, progress=gr.Progress(track_tqdm=True)):
     slider_values = [(i + 1) / float(n) for i in range(n)]  # (0,1] inclusive
 
     img = resize_image(input_image, 512)
-     pe, ppe = _encode_prompt(text_prompt)
+     pe, ppe, _ = PIPELINE.encode_prompt(prompt=text_prompt, prompt_2=text_prompt)
 
     results: List[Image.Image] = []
     gen_base = 64  # deterministic seed base
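
FLUX-family pipelines condition on two text encoders (CLIP and T5), which is why the same instruction is passed as both `prompt` and `prompt_2`, and why `encode_prompt` returns a triple whose last element (the text ids) is discarded here. A hedged sketch, assuming `FluxKontextSliderPipeline.encode_prompt` keeps the upstream diffusers signature:

    # Illustrative only: encode once, reuse the embeddings for every slider value.
    instruction = "make the sky more dramatic"  # hypothetical prompt
    pe, ppe, _ = PIPELINE.encode_prompt(prompt=instruction, prompt_2=instruction)
    # pe conditions the transformer's token stream; ppe is the pooled CLIP vector.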
@@ -350,14 +376,15 @@ def generate_image_stack_edits(text_prompt, n_edits, input_image, progress=gr.Progress(track_tqdm=True)):
     first = results[0] if results else None
     return results, first
 
- @spaces.GPU(duration=80)
- def generate_single_image(text_prompt, slider_value, input_image, progress=gr.Progress(track_tqdm=True)):
+ @spaces.GPU
+ def generate_single_image(text_prompt, slider_value, input_image):
     if not input_image or not text_prompt or text_prompt.startswith("Please select"):
         return None
 
     img = resize_image(input_image, 512)
     sv = float(slider_value)
     pe, ppe = _encode_prompt(text_prompt)
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
     gen = torch.Generator(device=DEVICE if DEVICE != "cpu" else "cpu").manual_seed(64)
     with torch.no_grad():
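
Both entry points seed their `torch.Generator` from the same base (64) so that a single-image edit reproduces the matching frame of the precomputed stack; note that `DEVICE if DEVICE != "cpu" else "cpu"` reduces to just `DEVICE`. A self-contained sketch of the seeding pattern (`make_generator` is a hypothetical helper, not in app.py):

    import torch

    def make_generator(seed: int = 64) -> torch.Generator:
        # A fresh generator per call keeps repeated edits bit-for-bit reproducible.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.Generator(device=device).manual_seed(seed)

    gen = make_generator()
    noise = torch.randn(4, generator=gen, device=gen.device)  # same tensor every run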
@@ -492,7 +519,14 @@ def process_user_upload(uploaded_image, user_prompt, n_edits_val):
 
     return processed_image, generated_list, first_result, slider_update
 
+
+ @spaces.GPU
+ def gpu_warmup():
+     return load_pipeline_single_gpu()
+
 with gr.Blocks() as demo:
+     # warming up the demo for the first run
+     demo.load(gpu_warmup)
     gr.Markdown("# Kontinuous Kontext - Continuous Strength Control for Instruction-based Image Editing")
 
     # Add description section
 
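The new `demo.load(gpu_warmup)` line registers a Gradio load event, so the first browser session triggers pipeline initialization before the user clicks anything; later calls hit the `PIPELINE is not None` fast path in the loader. A minimal standalone sketch of the same wiring (the `warmup` body is illustrative):

    import gradio as gr

    def warmup() -> str:
        # Heavy one-time initialization would go here (e.g. the lazy loader above).
        return "warm"

    with gr.Blocks() as demo:
        demo.load(warmup)  # fires each time a client connects to the page
        gr.Markdown("Model warms up on page load.")

    if __name__ == "__main__":
        demo.launch()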