# import spaces  # HF Spaces GPU decorator; uncomment on a ZeroGPU Space
import gradio as gr
import numpy as np
import random
import os

import torch

import python  # local helper module (python.py) that loads the model and runs generation

# The imports below are only needed by the commented-out in-process pipeline path further down.
# from huggingface_hub import hf_hub_download
# from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderKL
# from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
# from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
# from peft import PeftModel

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
token = os.getenv("HF_TKN")  # Hugging Face access token for the gated FLUX.1-dev weights

# In-process pipeline loading, kept for reference; this app delegates generation to
# python.generate_image instead.
# good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype, token=token).to(device)
# pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, token=token).to(device)
if torch.cuda.is_available():
    torch.cuda.empty_cache()

MAX_SEED = np.iinfo(np.int32).max

# Bind the custom live-preview method to the in-process pipeline (reference path only)
# pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
# python.model_loading()
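
# Assumed contract of the local helper module, inferred from the call site in infer()
# below (not documented in the original file):
#   python.generate_image(prompt, guidance_scale, aspect_ratio, seed,
#                         num_inference_steps, lora_weight) -> (image, used_seed)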



def infer(prompt, seed=42, randomize_seed=True, aspect_ratio="4:3 landscape 1152x896",
          lora_weight="lora_weight_rank_32_alpha_32.safetensors",
          guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):

    # Randomize the seed if requested
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # generator = torch.Generator().manual_seed(seed)  # only used by the commented in-process loop below

    # Path of the selected LoRA weight; only needed by the commented in-process path.
    lora_weight_path = os.path.join("loras", lora_weight)
    # pipe.load_lora_weights(lora_weight_path)
    # pipe.fuse_lora()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Generation is delegated to the local helper; it returns the image and the seed actually used.
    image, seed = python.generate_image(
        prompt,
        guidance_scale,
        aspect_ratio,
        seed,
        num_inference_steps,
        lora_weight,
    )
    # In-process live-preview generation, kept for reference. width/height would have to be
    # parsed out of the aspect_ratio string first (see the sketch below this function).
    # for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
    #         prompt=prompt,
    #         guidance_scale=guidance_scale,
    #         num_inference_steps=num_inference_steps,
    #         width=width,
    #         height=height,
    #         generator=generator,
    #         output_type="pil",
    #         good_vae=good_vae,
    #     ):
    #     out_img = img
    return image, seed
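
# The in-process loop above needs explicit width/height, while the UI only passes strings
# like "4:3 landscape 1152x896". A minimal sketch of a parser for those strings; the
# helper name is illustrative and not part of the original Space.
def parse_aspect_ratio(aspect_ratio):
    """Extract (width, height) from a dropdown string such as '4:3 landscape 1152x896'."""
    dims = aspect_ratio.rsplit(" ", 1)[-1]  # -> "1152x896"
    width, height = (int(v) for v in dims.split("x"))
    return width, height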

# Examples for the prompt
examples = [
    "Photo on a small glass panel. Color. A vintage Autochrome photograph, early 1900s aesthetic depicts four roses in a brown vase with dark background.",
    "Photo on a small glass panel. Color. A depiction of trees with orange leaves and a small path.",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# Text2Autochrome demo! 
        """)
        
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=5,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        
        result = gr.Image(label="Result", show_label=False)
        
        with gr.Accordion("Advanced Settings", open=True):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            
            # Dropdown for aspect ratio selection
            aspect_ratio = gr.Dropdown(
                label="Aspect Ratio",
                choices=["1:1 square 1024x1024", "3:4 portrait 896x1152", "5:8 portrait 832x1216", "9:16 portrait 768x1344", "4:3 landscape 1152x896", "3:2 landscape 1216x832", "16:9 landscape 1344x768"],
                value="4:3 landscape 1152x896",
                interactive=True,
            )
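            # The trailing "WxH" token in each choice carries the pixel dimensions that the
            # generation helper presumably parses (see the parse_aspect_ratio sketch above).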
            
            # Dropdown for LoRA weight selection
            lora_weight = gr.Dropdown(
                label="LoRA Weight",
                choices=[
                    "lora_weight_rank_16_alpha_32_1.safetensors",
                    "lora_weight_rank_16_alpha_32_2.safetensors",
                    "lora_weight_rank_32_alpha_32.safetensors",
                    "lora_weight_rank_32_alpha_64.safetensors",
                ],
                value="lora_weight_rank_16_alpha_32_1.safetensors",
                interactive=True,
            )
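            # File names encode the LoRA rank/alpha used during training; the files are
            # expected under the local loras/ directory (see lora_weight_path in infer).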
            
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=25,
                    step=0.1,
                    value=8.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
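
        # With cache_examples=False, selecting an example only fills the prompt box;
        # generation still runs through the Run button / prompt-submit handler below.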
        
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples=False
        )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, seed, randomize_seed, aspect_ratio, lora_weight, guidance_scale, num_inference_steps],
        outputs=[result, seed]
    )

demo.launch()