
Commit 3d085a2

addressed PR comments
1 parent 6d01ea0 commit 3d085a2


4 files changed: +14 −78 lines


docs/source/en/api/models/controlnet_sana.md

Lines changed: 1 addition & 6 deletions
@@ -24,17 +24,12 @@ The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana)
 ## Loading from the original format
 By default the [`SanaControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`]
 ```py
-from diffusers import SanaControlNetModel, SanaControlNetPipeline
+from diffusers import SanaControlNetModel
 import torch
 
 controlnet = SanaControlNetModel.from_pretrained(
     "ishan24/Sana_600M_1024px_ControlNet_diffusers",
 )
-pipe = SanaControlNetPipeline.from_pretrained(
-    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
-    controlnet=controlnet,
-)
-pipe.to('cuda')
 ```
 
 ## SanaControlNetModel
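
The model doc now stops at loading the ControlNet itself; wiring it into a pipeline is left to the pipeline docs. A minimal sketch of that standalone load, assuming the same checkpoint; the explicit `torch_dtype` and the `.dtype` check are illustrative additions rather than part of the documented snippet:

```py
import torch
from diffusers import SanaControlNetModel

# Load only the ControlNet weights; no pipeline is needed at this point.
# torch_dtype is optional and shown here purely as an example.
controlnet = SanaControlNetModel.from_pretrained(
    "ishan24/Sana_600M_1024px_ControlNet_diffusers",
    torch_dtype=torch.float16,
)
print(controlnet.dtype)  # torch.float16
```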

docs/source/en/api/pipelines/controlnet_sana.md

Lines changed: 1 addition & 4 deletions
@@ -42,12 +42,9 @@ pipe = SanaControlNetPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_600M_1024px_diffusers",
     variant="fp16",
     controlnet=controlnet,
-    torch_dtype=torch.float16,
+    torch_dtype={'default': torch.bfloat16, 'transformer': torch.float16},
 )
-
 pipe.to('cuda')
-pipe.vae.to(torch.bfloat16)
-pipe.text_encoder.to(torch.bfloat16)
 
 cond_image = load_image(
     "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"

src/diffusers/pipelines/sana/pipeline_sana_controlnet.py

Lines changed: 11 additions & 9 deletions
@@ -367,13 +367,7 @@ def encode_prompt(
             prompt_embeds = prompt_embeds[0][:, select_index]
             prompt_attention_mask = prompt_attention_mask[:, select_index]
 
-        if self.transformer is not None:
-            dtype = self.transformer.dtype
-        elif self.text_encoder is not None:
-            dtype = self.text_encoder.dtype
-        else:
-            dtype = None
-
+        dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
@@ -406,6 +400,8 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds[0]
 
         if do_classifier_free_guidance:
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
@@ -956,6 +952,7 @@ def __call__(
             height, width = control_image.shape[-2:]
 
             control_image = self.vae.encode(control_image).latent
+            control_image = control_image.to(self.vae.dtype)
             control_image = control_image * self.vae.config.scaling_factor
 
         else:
@@ -992,12 +989,14 @@ def __call__(
                     continue
 
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0])
 
                 # controlnet(s) inference
+                latent_model_input = latent_model_input.to(dtype=self.controlnet.dtype)
+                prompt_embeds = prompt_embeds.to(dtype=self.controlnet.dtype)
+                control_image = control_image.to(dtype=self.controlnet.dtype)
                 controlnet_block_samples = self.controlnet(
                     latent_model_input,
                     encoder_hidden_states=prompt_embeds,
@@ -1010,6 +1009,9 @@ def __call__(
                 )[0]
 
                 # predict noise model_output
+                latent_model_input = latent_model_input.to(dtype=self.transformer.dtype)
+                prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
+                controlnet_block_samples = controlnet_block_samples.to(dtype=self.transformer.dtype)
                 noise_pred = self.transformer(
                     latent_model_input,
                     encoder_hidden_states=prompt_embeds,
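
The pattern in these hunks is to cast each tensor to the dtype of the module that consumes it next, once before the controlnet forward pass and again before the transformer forward pass, so the two models can run in different precisions. A minimal, self-contained illustration using plain `nn.Linear` stand-ins (not diffusers code; diffusers models expose the equivalent information via `module.dtype`):

```py
import torch
import torch.nn as nn

# Stand-ins for components loaded in different precisions.
controlnet = nn.Linear(8, 8).to(torch.bfloat16)
transformer = nn.Linear(8, 8).to(torch.float16)

x = torch.randn(2, 8)  # fp32 latents straight from the scheduler

# Without these casts, mixing fp32/bf16/fp16 inputs and weights raises a
# dtype-mismatch RuntimeError, so each input is cast to the dtype of the
# module that will run next.
ctrl_out = controlnet(x.to(dtype=controlnet.weight.dtype))
noise_pred = transformer(ctrl_out.to(dtype=transformer.weight.dtype))
print(noise_pred.dtype)  # torch.float16
```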

tests/pipelines/sana/test_sana_controlnet.py

Lines changed: 1 addition & 59 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import inspect
 import unittest
 
@@ -27,12 +26,8 @@
     SanaControlNetPipeline,
     SanaTransformer2DModel,
 )
-from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
-    backend_empty_cache,
     enable_full_determinism,
-    require_torch_accelerator,
-    slow,
     torch_device,
 )
 
@@ -79,6 +74,7 @@ def get_dummy_components(self):
             sample_size=32,
         )
 
+        torch.manual_seed(0)
         transformer = SanaTransformer2DModel(
             patch_size=1,
             in_channels=4,
@@ -329,57 +325,3 @@ def test_inference_batch_single_identical(self):
     def test_float16_inference(self):
         # Requires higher tolerance as model seems very sensitive to dtype
         super().test_float16_inference(expected_max_diff=0.08)
-
-
-@slow
-@require_torch_accelerator
-class SanaPipelineIntegrationTests(unittest.TestCase):
-    prompt = "A painting of a squirrel eating a burger."
-
-    def setUp(self):
-        super().setUp()
-        gc.collect()
-        backend_empty_cache(torch_device)
-
-    def tearDown(self):
-        super().tearDown()
-        gc.collect()
-        backend_empty_cache(torch_device)
-
-    def test_sana_1024(self):
-        generator = torch.Generator("cpu").manual_seed(0)
-        controlnet = SanaControlNetModel.from_pretrained(
-            "ishan24/Sana_600M_1024px_ControlNet_diffusers", torch_dtype=torch.float16
-        )
-
-        pipe = SanaControlNetPipeline.from_pretrained(
-            "Efficient-Large-Model/Sana_600M_1024px_diffusers",
-            variant="fp16",
-            torch_dtype=torch.float16,
-            controlnet=controlnet,
-        )
-        pipe.vae.to(torch.bfloat16)
-        pipe.text_encoder.to(torch.bfloat16)
-        pipe.enable_model_cpu_offload(device=torch_device)
-        control_image = load_image(
-            "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
-        )
-
-        image = pipe(
-            prompt=self.prompt,
-            height=1024,
-            width=1024,
-            generator=generator,
-            num_inference_steps=20,
-            output_type="np",
-            control_image=control_image,
-        ).images[0]
-
-        image = image.flatten()
-        output_slice = np.concatenate((image[:16], image[-16:]))
-
-        # fmt: off
-        expected_slice = np.array([0.0427, 0.0789, 0.0662, 0.0464, 0.082, 0.0574, 0.0535, 0.0886, 0.0647, 0.0549, 0.0872, 0.0605, 0.0593, 0.0942, 0.0674, 0.0581, 0.0076, 0.0168, 0.0027, 0.0063, 0.0159, 0.0, 0.0071, 0.0198, 0.0034, 0.0105, 0.0212, 0.0, 0.0, 0.0166, 0.0042, 0.0125])
-        # fmt: on
-
-        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-4))
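
Beyond deleting the slow integration test (and its now-unused imports), the only functional addition here is the `torch.manual_seed(0)` call before the dummy transformer is built, which pins its random initialization independently of the components created earlier. A minimal sketch of that pattern; the component names are placeholders, not the test's real modules:

```py
import torch
import torch.nn as nn

def get_dummy_components():
    torch.manual_seed(0)
    vae = nn.Linear(4, 4)          # stand-in for the dummy VAE

    # Re-seed so the transformer's init does not depend on how many random
    # draws the earlier components consumed.
    torch.manual_seed(0)
    transformer = nn.Linear(4, 4)  # stand-in for the dummy transformer
    return {"vae": vae, "transformer": transformer}
```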
