Skip to content

Commit 3bc4cd9

Browse files
committed
enforce safety checker; use dummy checker in fast tests
1 parent fd837a8 commit 3bc4cd9

File tree

6 files changed

+210
-39
lines changed

6 files changed

+210
-39
lines changed

src/diffusers/pipelines/cosmos/pipeline_cosmos.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ class CosmosPipeline(DiffusionPipeline):
144144

145145
model_cpu_offload_seq = "text_encoder->transformer->vae"
146146
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
147-
_optional_components = ["safety_checker"]
148147

149148
def __init__(
150149
self,
@@ -153,19 +152,12 @@ def __init__(
153152
transformer: CosmosTransformer3DModel,
154153
vae: AutoencoderKLCosmos,
155154
scheduler: EDMEulerScheduler,
156-
safety_checker: CosmosSafetyChecker = None,
157-
requires_safety_checker: bool = True,
155+
safety_checker: CosmosSafetyChecker,
158156
):
159157
super().__init__()
160158

161-
if requires_safety_checker and safety_checker is None:
162-
safety_checker = CosmosSafetyChecker()
163159
if safety_checker is None:
164-
logger.warning(
165-
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. This "
166-
f"is in violation of the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
167-
f"Please ensure that you are compliant with the license agreement."
168-
)
160+
safety_checker = CosmosSafetyChecker()
169161

170162
self.register_modules(
171163
vae=vae,
@@ -175,7 +167,6 @@ def __init__(
175167
scheduler=scheduler,
176168
safety_checker=safety_checker,
177169
)
178-
self.register_to_config(requires_safety_checker=requires_safety_checker)
179170

180171
self.vae_scale_factor_temporal = (
181172
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
@@ -476,6 +467,13 @@ def __call__(
476467
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
477468
"""
478469

470+
if self.safety_checker is None:
471+
raise ValueError(
472+
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
473+
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
474+
f"Please ensure that you are compliant with the license agreement."
475+
)
476+
479477
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
480478
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
481479

src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
187187

188188
model_cpu_offload_seq = "text_encoder->transformer->vae"
189189
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
190-
_optional_components = ["safety_checker"]
191190

192191
def __init__(
193192
self,
@@ -196,19 +195,12 @@ def __init__(
196195
transformer: CosmosTransformer3DModel,
197196
vae: AutoencoderKLCosmos,
198197
scheduler: EDMEulerScheduler,
199-
safety_checker: CosmosSafetyChecker = None,
200-
requires_safety_checker: bool = True,
198+
safety_checker: CosmosSafetyChecker,
201199
):
202200
super().__init__()
203201

204-
if requires_safety_checker and safety_checker is None:
205-
safety_checker = CosmosSafetyChecker()
206202
if safety_checker is None:
207-
logger.warning(
208-
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. This "
209-
f"is in violation of the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
210-
f"Please ensure that you are compliant with the license agreement."
211-
)
203+
safety_checker = CosmosSafetyChecker()
212204

213205
self.register_modules(
214206
vae=vae,
@@ -218,7 +210,6 @@ def __init__(
218210
scheduler=scheduler,
219211
safety_checker=safety_checker,
220212
)
221-
self.register_to_config(requires_safety_checker=requires_safety_checker)
222213

223214
self.vae_scale_factor_temporal = (
224215
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
@@ -591,6 +582,13 @@ def __call__(
591582
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
592583
"""
593584

585+
if self.safety_checker is None:
586+
raise ValueError(
587+
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
588+
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
589+
f"Please ensure that you are compliant with the license agreement."
590+
)
591+
594592
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
595593
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
596594

tests/pipelines/cosmos/cosmos_guardrail.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2024 The HuggingFace Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# ===== This file is an implementation of a dummy guardrail for the fast tests =====
16+
17+
from typing import Union
18+
19+
import numpy as np
20+
import torch
21+
22+
from diffusers.configuration_utils import ConfigMixin
23+
from diffusers.models.modeling_utils import ModelMixin
24+
25+
26+
class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
    """Lightweight stand-in for the Cosmos guardrail, intended for the fast tests.

    No real checking is performed: every prompt is reported as safe and video frames are handed back
    unchanged. The object only remembers the device and dtype it was last moved to, so that `device`
    and `dtype` always return usable values.
    """

    def __init__(self) -> None:
        super().__init__()

        # Placement is tracked by hand rather than derived from module parameters.
        self._dtype = torch.float32
        self._device = torch.device("cpu")

    def check_text_safety(self, prompt: str) -> bool:
        """Report every prompt as safe."""
        return True

    def check_video_safety(self, frames: np.ndarray) -> np.ndarray:
        """Return `frames` untouched."""
        return frames

    def to(
        self, device: Union[str, torch.device, None] = None, dtype: Union[torch.dtype, None] = None
    ) -> "DummyCosmosSafetyChecker":
        """Record the requested device and/or dtype.

        Args:
            device: Target device. A `torch.dtype` passed here (the positional `module.to(dtype)` form)
                is treated as `dtype`.
            dtype: Target dtype.

        Returns:
            `self`, so calls can be chained.
        """
        if isinstance(device, torch.dtype):
            device, dtype = None, device
        # An omitted argument means "leave as is"; it must not reset the stored value to None.
        if device is not None:
            self._device = torch.device(device)
        if dtype is not None:
            self._dtype = dtype
        return self

    @property
    def device(self) -> torch.device:
        """Device most recently passed to `to()`; CPU until then."""
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype most recently passed to `to()`; `torch.float32` until then."""
        return self._dtype

tests/pipelines/cosmos/test_cosmos.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414

1515
import inspect
16+
import json
17+
import os
18+
import tempfile
1619
import unittest
1720

1821
import numpy as np
@@ -24,13 +27,21 @@
2427

2528
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
2629
from ..test_pipelines_common import PipelineTesterMixin, to_np
30+
from .cosmos_guardrail import DummyCosmosSafetyChecker
2731

2832

2933
enable_full_determinism()
3034

3135

36+
class CosmosPipelineWrapper(CosmosPipeline):
    """`CosmosPipeline` variant for the fast tests that always loads with the dummy guardrail."""

    @staticmethod
    def from_pretrained(*args, **kwargs):
        """Forward to `CosmosPipeline.from_pretrained`, forcing `safety_checker` to a dummy checker.

        Any `safety_checker` supplied by the caller is replaced.
        """
        forwarded_kwargs = {**kwargs, "safety_checker": DummyCosmosSafetyChecker()}
        return CosmosPipeline.from_pretrained(*args, **forwarded_kwargs)
41+
42+
3243
class CosmosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
33-
pipeline_class = CosmosPipeline
44+
pipeline_class = CosmosPipelineWrapper
3445
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
3546
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
3647
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
@@ -106,8 +117,7 @@ def get_dummy_components(self):
106117
"text_encoder": text_encoder,
107118
"tokenizer": tokenizer,
108119
# We cannot run the Cosmos Guardrail for fast tests due to the large model size
109-
"safety_checker": None,
110-
"requires_safety_checker": False,
120+
"safety_checker": DummyCosmosSafetyChecker(),
111121
}
112122
return components
113123

@@ -149,13 +159,6 @@ def test_inference(self):
149159
max_diff = np.abs(generated_video - expected_video).max()
150160
self.assertLessEqual(max_diff, 1e10)
151161

152-
def test_components_function(self):
153-
init_components = self.get_dummy_components()
154-
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
155-
pipe = self.pipeline_class(**init_components, requires_safety_checker=False)
156-
self.assertTrue(hasattr(pipe, "components"))
157-
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
158-
159162
def test_callback_inputs(self):
160163
sig = inspect.signature(self.pipeline_class.__call__)
161164
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
@@ -216,7 +219,7 @@ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
216219
assert output.abs().sum() < 1e10
217220

218221
def test_inference_batch_single_identical(self):
219-
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
222+
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
220223

221224
def test_attention_slicing_forward_pass(
222225
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
@@ -282,3 +285,61 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
282285
expected_diff_max,
283286
"VAE tiling should not affect the inference results",
284287
)
288+
289+
def test_serialization_with_variants(self):
    """Save the pipeline with the `fp16` variant and check each model component's output folder.

    Every `torch.nn.Module` component except the safety checker must be saved as a directory that is
    listed in `model_index.json` and contains at least one file whose name carries the variant right
    after the first dot.
    """
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    model_components = [
        component_name
        for component_name, component in pipe.components.items()
        if isinstance(component, torch.nn.Module)
    ]
    # The safety checker is left out of the variant check.
    model_components.remove("safety_checker")
    variant = "fp16"

    with tempfile.TemporaryDirectory() as tmpdir:
        pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)

        with open(os.path.join(tmpdir, "model_index.json"), "r") as f:
            config = json.load(f)

        for subfolder in os.listdir(tmpdir):
            if subfolder not in model_components:
                continue

            # Resolve against tmpdir: a bare entry name would be checked relative to the current
            # working directory instead of the saved pipeline.
            folder_path = os.path.join(tmpdir, subfolder)
            self.assertTrue(
                os.path.isdir(folder_path) and subfolder in config,
                f"Component '{subfolder}' should be a folder listed in model_index.json",
            )

            # `partition` yields "" for names without a dot, so such files count as "no variant"
            # instead of raising IndexError.
            saved_files = os.listdir(folder_path)
            self.assertTrue(
                any(name.partition(".")[2].startswith(variant) for name in saved_files),
                f"No '{variant}' variant file found for component '{subfolder}': {saved_files}",
            )
311+
312+
def test_torch_dtype_dict(self):
    """Reload a saved pipeline with a per-component `torch_dtype` dict and verify the resulting dtypes.

    The first dummy component is requested in bfloat16 and everything else falls back to the
    "default" entry (float16). The safety checker is excluded from the comparison.
    """
    dummy_components = self.get_dummy_components()
    if not dummy_components:
        self.skipTest("No dummy components defined.")

    pipeline = self.pipeline_class(**dummy_components)
    first_component_name = next(iter(dummy_components))
    requested_dtypes = {first_component_name: torch.bfloat16, "default": torch.float16}

    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as save_dir:
        pipeline.save_pretrained(save_dir, safe_serialization=False)
        reloaded = self.pipeline_class.from_pretrained(
            save_dir, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=requested_dtypes
        )

    fallback_dtype = requested_dtypes.get("default", torch.float32)
    for name, component in reloaded.components.items():
        if name == "safety_checker":
            continue
        if not (isinstance(component, torch.nn.Module) and hasattr(component, "dtype")):
            continue

        expected_dtype = requested_dtypes.get(name, fallback_dtype)
        self.assertEqual(
            component.dtype,
            expected_dtype,
            f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
        )
338+
339+
@unittest.skip(
    "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
    "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
    "too large and slow to run on CI."
)
def test_encode_prompt_works_in_isolation(self):
    """Intentionally empty and always skipped; the reason is given in the skip message."""

0 commit comments

Comments
 (0)