airplane194 commited on
Commit
81cec35
·
1 Parent(s): 8bb51aa
Files changed (1) hide show
  1. python.py +21 -2
python.py CHANGED
@@ -4,6 +4,8 @@ import sys
4
  from typing import Sequence, Mapping, Any, Union
5
  import torch
6
  import spaces
 
 
7
  # from comfy import model_management
8
  from nodes import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_1
9
  from comfy_extras.nodes_custom_sampler import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_2
@@ -23,6 +25,21 @@ hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors
23
  hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders", token = token)
24
  hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders", token = token)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
28
  """Returns the value at the given index of a sequence or mapping.
@@ -266,5 +283,7 @@ def generate_image(prompt,
266
  # saveimage_9 = saveimage.save_images(
267
  # filename_prefix="image", images=get_value_at_index(vaedecode_8, 0)
268
  # )
269
- image = to_pil_image(get_value_at_index(vaedecode_8, 0).cpu())
270
- return image, seed
 
 
 
4
  from typing import Sequence, Mapping, Any, Union
5
  import torch
6
  import spaces
7
+ import numpy as np
8
+ from PIL import Image
9
  # from comfy import model_management
10
  from nodes import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_1
11
  from comfy_extras.nodes_custom_sampler import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_2
 
25
  hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders", token = token)
26
  hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders", token = token)
27
 
28
+ def preprocess_image_tensor(image):
29
+ # If image has a batch dimension (shape: [1, C, H, W]), remove it.
30
+ if image.ndim == 4 and image.shape[0] == 1:
31
+ image = image.squeeze(0)
32
+ # If image is in channels-first format (i.e. [C, H, W]) and has 3 or 4 channels,
33
+ # convert it to channels-last format ([H, W, C]).
34
+ if image.ndim == 3 and image.shape[0] in [1, 3, 4]:
35
+ image = image.permute(1, 2, 0)
36
+ # Ensure the image values are between 0 and 1. Then scale them to [0, 255].
37
+ image = image.detach().cpu().numpy()
38
+ image = np.clip(image, 0, 1) * 255
39
+ # Convert to unsigned 8-bit integer type.
40
+ image = image.astype(np.uint8)
41
+ return image
42
+
43
 
44
  def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
45
  """Returns the value at the given index of a sequence or mapping.
 
283
  # saveimage_9 = saveimage.save_images(
284
  # filename_prefix="image", images=get_value_at_index(vaedecode_8, 0)
285
  # )
286
+ image_tensor = get_value_at_index(vaedecode_8, 0)
287
+ preprocessed_image = preprocess_image_tensor(image_tensor)
288
+ pil_image = Image.fromarray(preprocessed_image)
289
+ return pil_image, seed