"""Gradio demo: FLUX.1-dev text-to-image with Tencent SRPO transformer weights.

At import time this script loads the FLUX.1-dev pipeline onto CUDA, overwrites
its transformer weights with the SRPO checkpoint, builds the Gradio UI and
launches it.
"""

import os

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline
from safetensors.torch import load_file

# Direct download URL of the SRPO transformer checkpoint
# (model card: https://huggingface.co/tencent/SRPO).
SRPO_WEIGHTS_URL = (
    "https://huggingface.co/tencent/SRPO/resolve/main/"
    "diffusion_pytorch_model.safetensors"
)


def _fetch_srpo_weights(url=SRPO_WEIGHTS_URL):
    """Return a local path to the SRPO checkpoint, downloading it if absent.

    safetensors' ``load_file`` only opens local files, so a Hub repo path
    cannot be passed to it directly. The file is stored under torch's hub
    cache directory so later restarts skip the download.
    """
    path = os.path.join(
        torch.hub.get_dir(), "tencent_SRPO", "diffusion_pytorch_model.safetensors"
    )
    if not os.path.isfile(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.hub.download_url_to_file(url, path)
    return path


# Load the base model.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
).to("cuda")

# Overwrite the transformer with the SRPO-finetuned weights.
state_dict = load_file(_fetch_srpo_weights())
pipe.transformer.load_state_dict(state_dict)
# load_state_dict copied the tensors into the model; drop the host-RAM copy
# instead of keeping a second full set of weights alive for the whole process.
del state_dict


# NOTE(review): large resolutions with many steps may exceed 120 s of GPU time;
# confirm the duration budget against real timings.
@spaces.GPU(duration=120)
def generate_image(
    prompt,
    negative_prompt="",
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    seed=-1,
):
    """Generate one image with the SRPO-enhanced FLUX pipeline.

    Args:
        prompt: Text description of the desired image.
        negative_prompt: Optional text to steer away from. Empty or
            whitespace-only means "no negative prompt". NOTE(review): in
            recent diffusers, FluxPipeline only applies a negative prompt when
            ``true_cfg_scale > 1``, which this call does not set — confirm
            whether the field has any effect.
        width: Output width in pixels (the UI slider uses multiples of 64).
        height: Output height in pixels (the UI slider uses multiples of 64).
        guidance_scale: Forwarded to the pipeline's ``guidance_scale``.
        num_inference_steps: Number of denoising steps.
        seed: RNG seed. ``None`` (cleared field) or any negative value
            selects a random seed in ``[0, 2**32)``.

    Returns:
        Tuple of (first image returned by the pipeline, seed actually used).
    """
    # A cleared gr.Number arrives as None; numeric widgets may deliver floats.
    if seed is None or int(seed) < 0:
        seed = torch.randint(0, 2**32, (1,)).item()
    seed = int(seed)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Treat blank/whitespace-only input as "no negative prompt".
    if not negative_prompt or not negative_prompt.strip():
        negative_prompt = None

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=generator,
    ).images[0]
    return image, seed


with gr.Blocks(title="FLUX SRPO Text-to-Image") as demo:
    gr.Markdown("# FLUX with SRPO (Self-Regulating Preference Optimization)")
    gr.Markdown("Generate high-quality images using FLUX model enhanced with Tencent's SRPO technique")

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to generate...",
                lines=3,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt (optional)",
                placeholder="What you don't want to see in the image...",
                lines=2,
            )
            with gr.Row():
                width = gr.Slider(
                    minimum=256, maximum=2048, value=1024, step=64, label="Width"
                )
                height = gr.Slider(
                    minimum=256, maximum=2048, value=1024, step=64, label="Height"
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=20.0, value=3.5, step=0.5, label="Guidance Scale"
                )
                num_inference_steps = gr.Slider(
                    minimum=10, maximum=100, value=50, step=5, label="Inference Steps"
                )
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            generate_btn = gr.Button("Generate Image", variant="primary", size="lg")

        # Right column: result and the seed that produced it.
        with gr.Column(scale=4):
            output_image = gr.Image(label="Generated Image", type="pil")
            used_seed = gr.Number(label="Seed Used", precision=0)

    gr.Examples(
        examples=[
            ["The Death of Ophelia by John Everett Millais, Pre-Raphaelite painting, Ophelia floating in a river surrounded by flowers, detailed natural elements, melancholic and tragic atmosphere"],
            ["A serene Japanese garden with cherry blossoms, koi pond, traditional wooden bridge, soft morning light, photorealistic"],
            ["Cyberpunk cityscape at night, neon lights, flying cars, rain-slicked streets, blade runner aesthetic, highly detailed"],
            ["Portrait of a majestic lion in golden hour light, detailed fur texture, intense gaze, African savanna background"],
            ["Abstract colorful explosion of paint in water, high speed photography, vibrant colors mixing, dramatic lighting"],
        ],
        inputs=prompt,
        label="Example Prompts",
    )

    # Input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, width, height, guidance_scale, num_inference_steps, seed],
        outputs=[output_image, used_seed],
    )

demo.launch()