Skip to content

Commit ef57544

Browse files
committed
fix trt support
1 parent 4343061 commit ef57544

File tree

13 files changed

+104
-104
lines changed

13 files changed

+104
-104
lines changed

README.md

Lines changed: 8 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,22 @@ This repo contains minimal inference code to run image generation & editing with
1010
```bash
1111
cd $HOME && git clone https://github.com/black-forest-labs/flux
1212
cd $HOME/flux
13-
14-
# Using pyvenv
1513
python3.10 -m venv .venv
1614
source .venv/bin/activate
1715
pip install -e ".[all]"
1816
```
1917

20-
## Local installation with TRT support
18+
### Local installation with TensorRT support
19+
20+
If you would like to install the repository with [TensorRT](https://github.com/NVIDIA/TensorRT) support, you currently need to install a PyTorch image from NVIDIA instead. First install [enroot](https://github.com/NVIDIA/enroot), next follow the steps below:
2121

2222
```bash
23-
docker pull nvcr.io/nvidia/pytorch:24.10-py3
2423
cd $HOME && git clone https://github.com/black-forest-labs/flux
25-
cd $HOME/flux
26-
docker run --rm -it --gpus all -v $PWD:/workspace/flux nvcr.io/nvidia/pytorch:24.10-py3 /bin/bash
27-
# inside container
28-
cd /workspace/flux
29-
pip install -e ".[all]"
30-
pip install -r trt_requirements.txt
24+
enroot import 'docker://$oauthtoken@nvcr.io#nvidia/pytorch:25.01-py3'
25+
enroot create -n pti2501 nvidia+pytorch+25.01-py3.sqsh
26+
enroot start --rw -m ${PWD}/flux:/workspace/flux -r pti2501
27+
cd flux
28+
pip install -e ".[tensorrt]" --extra-index-url https://pypi.nvidia.com
3129
```
3230

3331
### Models
@@ -55,57 +53,6 @@ We are offering an extensive suite of models. For more information about the inv
5553

5654
The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above.
5755

58-
We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo:
59-
60-
```bash
61-
python demo_gr.py --name flux-schnell --device cuda
62-
```
63-
64-
Options:
65-
66-
- `--name`: Choose the model to use (options: "flux-schnell", "flux-dev")
67-
- `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu")
68-
- `--offload`: Offload model to CPU when not in use
69-
- `--share`: Create a public link to your demo
70-
71-
To run the demo with the dev model and create a public link:
72-
73-
```bash
74-
python demo_gr.py --name flux-dev --share
75-
```
76-
77-
## Diffusers integration
78-
79-
`FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library. To use it with diffusers, install it:
80-
81-
```shell
82-
pip install git+https://github.com/huggingface/diffusers.git
83-
```
84-
85-
Then you can use `FluxPipeline` to run the model
86-
87-
```python
88-
import torch
89-
from diffusers import FluxPipeline
90-
91-
model_id = "black-forest-labs/FLUX.1-schnell" #you can also use `black-forest-labs/FLUX.1-dev`
92-
93-
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
94-
pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
95-
96-
prompt = "A cat holding a sign that says hello world"
97-
seed = 42
98-
image = pipe(
99-
prompt,
100-
output_type="pil",
101-
num_inference_steps=4, #use a larger number if you are using [dev]
102-
generator=torch.Generator("cpu").manual_seed(seed)
103-
).images[0]
104-
image.save("flux-schnell.png")
105-
```
106-
107-
To learn more check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation
108-
10956
## API usage
11057

11158
Our API offers access to our models. It is documented here:

demo_gr.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
NSFW_THRESHOLD = 0.85
1717

18+
1819
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
1920
t5 = load_t5(device, max_length=256 if is_schnell else 512)
2021
clip = load_clip(device)
@@ -23,6 +24,7 @@ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool)
2324
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
2425
return model, ae, t5, clip, nsfw_classifier
2526

27+
2628
class FluxGenerator:
2729
def __init__(self, model_name: str, device: str, offload: bool):
2830
self.device = torch.device(device)

docs/structural-conditioning.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ python -m src.flux.cli_control --loop --name <name>
3939

4040
where `name` is one of `flux-dev-canny`, `flux-dev-depth`, `flux-dev-canny-lora`, or `flux-dev-depth-lora`.
4141

42+
### TRT engine infernece
43+
44+
You may also download ONNX export of [FLUX.1 Depth \[dev\]](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-onnx) and [FLUX.1 Canny \[dev\]](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-onnx). We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md).
45+
46+
```bash
47+
TRT_ENGINE_DIR=<your_trt_engine_will_be_saved_here> ONNX_DIR=<path_of_downloaded_onnx_export> python src/flux/cli.py "<prompt>" --img_cond_path="assets/robot.webp" --trt --static_shape=False --name=<name> --trt_transformer_precision <precision>
48+
```
49+
where `<precision>` is either bf16, fp8, or fp4. For fp4, you need a NVIDIA GPU based on the [Blackwell Architecture](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/).
50+
4251
## Diffusers usage
4352

4453
Flux Control (including the LoRAs) is also compatible with the `diffusers` Python library. Check out the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.

docs/text-to-image.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ python -m flux --name <name> \
3535
--prompt "<prompt>"
3636
```
3737

38+
### TRT engine infernece
39+
40+
You may also download ONNX export of [FLUX.1 \[dev\]](https://huggingface.co/black-forest-labs/FLUX.1-dev-onnx) and [FLUX.1 \[schnell\]](https://huggingface.co/black-forest-labs/FLUX.1-schnell-onnx). We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md).
41+
42+
```bash
43+
TRT_ENGINE_DIR=<your_trt_engine_will_be_saved_here> ONNX_DIR=<path_of_downloaded_onnx_export> python src/flux/cli.py "<prompt>" --trt --static_shape=False --name=<name> --trt_transformer_precision <precision>
44+
```
45+
where `<precision>` is either bf16, fp8, or fp4. For fp4, you need a NVIDIA GPU based on the [Blackwell Architecture](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/).
46+
47+
### Streamlit and Gradio
48+
3849
We also provide a streamlit demo that does both text-to-image and image-to-image. The demo can be run via
3950

4051
```bash

pyproject.toml

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ requires-python = ">=3.10"
99
license = { file = "LICENSE.md" }
1010
dynamic = ["version"]
1111
dependencies = [
12-
"torch == 2.5.1",
13-
"torchvision",
1412
"einops",
1513
"fire >= 0.6.0",
1614
"huggingface-hub",
@@ -25,6 +23,10 @@ dependencies = [
2523
]
2624

2725
[project.optional-dependencies]
26+
torch = [
27+
"torch == 2.5.1",
28+
"torchvision",
29+
]
2830
streamlit = [
2931
"streamlit",
3032
"streamlit-drawable-canvas",
@@ -33,9 +35,22 @@ streamlit = [
3335
gradio = [
3436
"gradio",
3537
]
38+
tensorrt = [
39+
"tensorrt-cu12 == 10.8.0.43",
40+
"colored",
41+
"cuda-python",
42+
"diffusers",
43+
"nvidia-modelopt[torch,onnx] ~= 0.19.0",
44+
"opencv-python ~= 4.8.0.74",
45+
"onnx ~= 1.17.0",
46+
"onnxruntime ~= 1.19.2",
47+
"onnx-graphsurgeon",
48+
"polygraphy ~= 0.49.9",
49+
]
3650
all = [
37-
"flux[streamlit]",
3851
"flux[gradio]",
52+
"flux[streamlit]",
53+
"flux[torch]",
3954
]
4055

4156
[project.scripts]

src/flux/cli.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ class SamplingOptions:
2727

2828

2929
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
30-
user_question = (
31-
"Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
32-
)
30+
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
3331
usage = (
3432
"Usage: Either write your prompt directly, leave this field empty "
3533
"to repeat the prompt or write a command starting with a slash:\n"
@@ -113,6 +111,7 @@ def main(
113111
output_dir: str = "output",
114112
add_sampling_metadata: bool = True,
115113
trt: bool = False,
114+
trt_transformer_precision: str = "bf16",
116115
**kwargs: dict | None,
117116
):
118117
"""
@@ -135,6 +134,19 @@ def main(
135134
trt: use TensorRT backend for optimized inference
136135
kwargs: additional arguments for TensorRT support
137136
"""
137+
138+
prompt = prompt.split("|")
139+
if len(prompt) == 1:
140+
prompt = prompt[0]
141+
additional_prompts = None
142+
else:
143+
additional_prompts = prompt[1:]
144+
prompt = prompt[0]
145+
146+
assert not (
147+
(additional_prompts is not None) and loop
148+
), "Do not provide additional prompts and set loop to True"
149+
138150
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
139151

140152
if name not in configs:
@@ -193,6 +205,7 @@ def main(
193205
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
194206
opt_image_height=height,
195207
opt_image_width=width,
208+
transformer_precision=trt_transformer_precision,
196209
)
197210

198211
torch.cuda.synchronize()
@@ -251,9 +264,7 @@ def main(
251264
torch.cuda.empty_cache()
252265
t5, clip = t5.to(torch_device), clip.to(torch_device)
253266
inp = prepare(t5, clip, x, prompt=opts.prompt)
254-
timesteps = get_schedule(
255-
opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")
256-
)
267+
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
257268

258269
# offload TEs to CPU, load model to gpu
259270
if offload:
@@ -287,12 +298,16 @@ def main(
287298
if loop:
288299
print("-" * 80)
289300
opts = parse_prompt(opts)
301+
elif additional_prompts:
302+
next_prompt = additional_prompts.pop(0)
303+
opts.prompt = next_prompt
290304
else:
291305
opts = None
292306

293307
if trt:
294308
trt_ctx_manager.stop_runtime()
295309

310+
296311
def app():
297312
Fire(main)
298313

src/flux/cli_control.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def main(
177177
img_cond_path: str = "assets/robot.webp",
178178
lora_scale: float | None = 0.85,
179179
trt: bool = False,
180+
trt_transformer_precision: str = "bf16",
180181
**kwargs: dict | None,
181182
):
182183
"""
@@ -272,6 +273,7 @@ def main(
272273
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
273274
opt_image_height=height,
274275
opt_image_width=width,
276+
transformer_precision=trt_transformer_precision,
275277
)
276278
torch.cuda.synchronize()
277279

src/flux/trt/engine/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from flux.trt.engine.clip_engine import CLIPEngine
1919
from flux.trt.engine.t5_engine import T5Engine
2020
from flux.trt.engine.transformer_engine import TransformerEngine
21-
from flux.trt.engine.vae_engine import VAEEngine, VAEDecoder, VAEEncoder
21+
from flux.trt.engine.vae_engine import VAEDecoder, VAEEncoder, VAEEngine
2222

2323
__all__ = [
2424
"BaseEngine",

src/flux/trt/engine/vae_engine.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
# limitations under the License.
1616

1717
import torch
18+
from cuda import cudart
1819

1920
from flux.trt.engine.base_engine import BaseEngine, Engine
2021
from flux.trt.mixin import VAEMixin
21-
from cuda import cudart
2222

2323

2424
class VAEDecoder(VAEMixin, Engine):
@@ -162,7 +162,6 @@ def load(self):
162162
if self.encoder is not None:
163163
self.encoder.load()
164164

165-
166165
def activate(
167166
self,
168167
device: str,

src/flux/trt/exporter/vae_exporter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import torch
1817
from math import ceil
18+
19+
import torch
20+
1921
from flux.modules.autoencoder import Decoder, Encoder
2022
from flux.trt.exporter.base_exporter import BaseExporter
2123
from flux.trt.mixin import VAEMixin

0 commit comments

Comments
 (0)