@@ -2,11 +2,11 @@
 import time
 import uuid
 
-import torch
 import gradio as gr
 import numpy as np
+import torch
 from einops import rearrange
-from PIL import Image, ExifTags
+from PIL import ExifTags, Image
 from transformers import pipeline
 
 from flux.cli import SamplingOptions
@@ -15,6 +15,7 @@
 
 NSFW_THRESHOLD = 0.85
 
+
 def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
     t5 = load_t5(device, max_length=256 if is_schnell else 512)
     clip = load_clip(device)
@@ -23,6 +24,7 @@ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool)
     nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
     return model, ae, t5, clip, nsfw_classifier
 
+
 class FluxGenerator:
     def __init__(self, model_name: str, device: str, offload: bool):
         self.device = torch.device(device)
@@ -70,7 +72,7 @@ def generate_image(
         if init_image is not None:
             if isinstance(init_image, np.ndarray):
                 init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0
-            init_image = init_image.unsqueeze(0)
+                init_image = init_image.unsqueeze(0)
             init_image = init_image.to(self.device)
             init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
             if self.offload:
@@ -151,37 +153,49 @@ def generate_image(
             exif_data[ExifTags.Base.Model] = self.model_name
             if add_sampling_metadata:
                 exif_data[ExifTags.Base.ImageDescription] = prompt
-
+
             img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)
 
             return img, str(opts.seed), filename, None
         else:
             return None, str(opts.seed), None, "Your generated image may contain NSFW content."
 
-def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
+
+def create_demo(
+    model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False
+):
     generator = FluxGenerator(model_name, device, offload)
     is_schnell = model_name == "flux-schnell"
 
     with gr.Blocks() as demo:
         gr.Markdown(f"# Flux Image Generation Demo - Model: {model_name}")
-
+
         with gr.Row():
            with gr.Column():
-                prompt = gr.Textbox(label="Prompt", value="a photo of a forest with mist swirling around the tree trunks. The word \"FLUX\" is painted over it in big, red brush strokes with visible texture")
+                prompt = gr.Textbox(
+                    label="Prompt",
+                    value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture',
+                )
                 do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell)
                 init_image = gr.Image(label="Input Image", visible=False)
-                image2image_strength = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False)
-
+                image2image_strength = gr.Slider(
+                    0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False
+                )
+
                 with gr.Accordion("Advanced Options", open=False):
                     width = gr.Slider(128, 8192, 1360, step=16, label="Width")
                     height = gr.Slider(128, 8192, 768, step=16, label="Height")
                     num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps")
-                    guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell)
+                    guidance = gr.Slider(
+                        1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell
+                    )
                     seed = gr.Textbox(-1, label="Seed (-1 for random)")
-                    add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True)
-
+                    add_sampling_metadata = gr.Checkbox(
+                        label="Add sampling parameters to metadata?", value=True
+                    )
+
                 generate_btn = gr.Button("Generate")
-
+
             with gr.Column():
                 output_image = gr.Image(label="Generated Image")
                 seed_output = gr.Number(label="Used Seed")
@@ -198,17 +212,33 @@ def update_img2img(do_img2img):
 
         generate_btn.click(
             fn=generator.generate_image,
-            inputs=[width, height, num_steps, guidance, seed, prompt, init_image, image2image_strength, add_sampling_metadata],
+            inputs=[
+                width,
+                height,
+                num_steps,
+                guidance,
+                seed,
+                prompt,
+                init_image,
+                image2image_strength,
+                add_sampling_metadata,
+            ],
             outputs=[output_image, seed_output, download_btn, warning_text],
         )
 
     return demo
 
+
 if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser(description="Flux")
-    parser.add_argument("--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name")
-    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument(
+        "--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name"
+    )
+    parser.add_argument(
+        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use"
+    )
     parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
     parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
     args = parser.parse_args()
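
Rough usage sketch, not part of the commit: besides the CLI flags above, the reformatted entry point can be driven programmatically, since create_demo() returns the gr.Blocks app. The module name demo_gr is an assumption; the diff does not name the file.

    # Hypothetical import path; "demo_gr" is assumed, not confirmed by the diff.
    from demo_gr import create_demo

    # create_demo() builds the Blocks UI shown in the diff; launch(share=...) is standard Gradio API.
    demo = create_demo("flux-schnell", device="cpu", offload=False)
    demo.launch(share=False)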