Commit abf7ec5

Support receiving a list of generators in the Dream scheduler
1 parent 31977e6 commit abf7ec5

File tree

2 files changed: +131 -14 lines changed

src/diffusers/schedulers/scheduling_dream.py

Lines changed: 11 additions & 14 deletions
@@ -6,6 +6,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
+from ..utils.torch_utils import multinomial_tensor, rand_tensor
 from .scheduling_utils import SchedulerMixin
 
 
@@ -75,7 +76,7 @@ def sample_tokens(
     top_k: Optional[int] = None,
     margin_confidence: bool = False,
     neg_entropy: bool = False,
-    generator: Optional[torch.Generator] = None,
+    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Samples from a sequence of logits of shape [..., vocab_size] and returns both the sampled sequence (as the second
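With the widened `generator` type, callers can seed each batch element separately. A minimal sketch (shapes and seeds are arbitrary; the import assumes this PR's module layout):

    import torch
    from diffusers.schedulers.scheduling_dream import sample_tokens

    logits = torch.randn(2, 16, 32000)  # [batch, seq_len, vocab_size]
    generators = [torch.Generator().manual_seed(s) for s in (0, 1)]  # one seed per prompt
    # confidence: probability of each sampled id; tokens: the sampled ids, shape [2, 16]
    confidence, tokens = sample_tokens(logits, temperature=1.0, generator=generators)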
@@ -91,19 +92,11 @@ def sample_tokens(
         logits = top_k_logits(logits, top_k)
 
     probs = torch.softmax(logits, dim=-1)
-    device = probs.device
-    probs_ = probs.to(generator.device) if generator is not None else probs  # handles when generator is on CPU
-    if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
-        probs_ = probs_.float()  # multinomial is not implemented for cpu half precision
-    if probs.ndim > 2:
-        probs_ = probs_.reshape(-1, probs.size(-1))  # [B, L, V] --> [B * L, V]
 
     if temperature > 0:
         try:
             # Sample x0 ~ Cat(probs)
-            x0 = torch.multinomial(probs_, 1, generator=generator).to(device=device)
-            if probs.ndim > 2:
-                x0 = x0[:, 0].view(*probs.shape[:-1])  # [B * L, 1] --> [B, L]
+            x0 = multinomial_tensor(probs, 1, generator=generator, device=logits.device)
             confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)  # [B, L]
         except:
             confidence, x0 = probs.max(dim=-1)
@@ -349,6 +342,7 @@ def step(
         timestep: Union[float, torch.Tensor],
         sample: torch.Tensor,
         generator: Optional[torch.Generator] = None,
+        noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[DreamMaskedDiffusionSchedulerOutput, Tuple]:
         """
@@ -364,6 +358,9 @@ def step(
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
+            noise (`torch.Tensor`, *optional*):
+                Allows the noise to be specified directly as an alternative to generating noise with the generator.
+                Note that this noise should be drawn from the uniform distribution over [0, 1).
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                 tuple.
@@ -396,7 +393,9 @@ def step(
 
         # TODO: mask logits (model_output) beforehand? might make it more efficient?
         if self.config.logit_sampling_alg == "origin":
-            to_unmask_mask = torch.rand(*sample.shape, generator=generator, device=sample.device) < unmask_prob
+            if noise is None:
+                noise = rand_tensor(sample.shape, generator=generator, device=sample.device)
+            to_unmask_mask = noise < unmask_prob
            confidence, pred_original_sample = sample_tokens(
                model_output, temperature=temperature, top_p=top_p, top_k=top_k, generator=generator
            )
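The new `noise` argument lets callers supply the uniform noise behind the "origin" unmasking rule directly, e.g. for reuse across runs. A hedged sketch where `scheduler`, `model_output`, `t`, and `mask_token_id` are hypothetical stand-ins:

    import torch

    sample = torch.full((2, 16), mask_token_id)  # hypothetical fully-masked token ids
    noise = torch.rand(sample.shape)  # must be uniform over [0, 1), matching rand_tensor
    out = scheduler.step(model_output, t, sample, noise=noise)  # unmask decisions now use `noise`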
@@ -437,9 +436,7 @@ def step(
             else:
                 full_confidence = full_confidence / self.config.alg_temperature
             full_confidence = F.softmax(full_confidence, dim=-1)
-            unmask_indices = torch.multinomial(
-                full_confidence, num_samples=num_tokens_to_unmask, generator=generator
-            )
+            unmask_indices = multinomial_tensor(full_confidence, num_tokens_to_unmask, generator=generator)
             unmask_indices = unmask_indices.to(sample.device)
 
             row_indices = torch.arange(sample.size(0), device=sample.device).unsqueeze(1).expand_as(unmask_indices)
src/diffusers/utils/torch_utils.py

Lines changed: 120 additions & 0 deletions

@@ -86,6 +86,126 @@ def randn_tensor(
     return latents
 
 
+def rand_tensor(
+    shape: Union[Tuple, List],
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional[Union[str, "torch.device"]] = None,
+    dtype: Optional["torch.dtype"] = None,
+    layout: Optional["torch.layout"] = None,
+):
+    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
+    passing a list of generators, you can seed each batch element individually. If CPU generators are passed, the
+    tensor is always created on the CPU. This is analogous to `randn_tensor`, except it creates random tensors from
+    the uniform distribution over [0, 1) using `torch.rand`.
+    """
+    # device on which tensor is created defaults to device
+    if isinstance(device, str):
+        device = torch.device(device)
+    rand_device = device
+    batch_size = shape[0]
+
+    layout = layout or torch.strided
+    device = device or torch.device("cpu")
+
+    if generator is not None:
+        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
+        if gen_device_type != device.type and gen_device_type == "cpu":
+            rand_device = "cpu"
+            if device != "mps":
+                logger.info(
+                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                    f" slightly speed up this function by passing a generator that was created on the {device} device."
+                )
+        elif gen_device_type != device.type and gen_device_type == "cuda":
+            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+    # make sure generator list of length 1 is treated like a non-list
+    if isinstance(generator, list) and len(generator) == 1:
+        generator = generator[0]
+
+    if isinstance(generator, list):
+        shape = (1,) + shape[1:]
+        latents = [
+            torch.rand(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
+            for i in range(batch_size)
+        ]
+        latents = torch.cat(latents, dim=0).to(device)
+    else:
+        latents = torch.rand(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
+
+    return latents
+
+
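A usage sketch for the new helper (seeds arbitrary): with a list of generators, row i is drawn from generator i, so each batch element reproduces independently of the overall batch size.

    import torch
    from diffusers.utils.torch_utils import rand_tensor

    gens = [torch.Generator().manual_seed(i) for i in range(4)]
    u = rand_tensor((4, 8), generator=gens)  # row i is drawn from gens[i]
    u0 = rand_tensor((1, 8), generator=torch.Generator().manual_seed(0))
    assert torch.equal(u[:1], u0)  # first row matches a fresh batch-of-1 run with the same seed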
+def multinomial_tensor(
+    logits: torch.Tensor,
+    num_samples: int,
+    replacement: bool = False,
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional[Union[str, "torch.device"]] = None,
+    squeeze_trailing_dim: bool = True,
+):
+    """
+    Creates a tensor drawn from the multinomial distribution specified by the (possibly unnormalized) probabilities
+    given by `logits`. This is analogous to `randn_tensor`, wrapping `torch.multinomial` rather than `torch.randn`.
+
+    In general, if `logits` has shape [..., num_categories], where the ... represents leading batch dimensions, the
+    output will have shape [..., num_samples]. `logits` is assumed to have at least one leading batch dimension.
+    """
+    batch_size = logits.shape[0]
+    num_cats = logits.shape[-1]
+
+    device = device or torch.device("cpu")
+
+    if generator is not None:
+        gen_device = generator.device if not isinstance(generator, list) else generator[0].device
+        gen_device_type = gen_device.type
+        if gen_device_type != device.type and gen_device_type == "cpu":
+            if device != "mps":
+                logger.info(
+                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                    f" slightly speed up this function by passing a generator that was created on the {device} device."
+                )
+        elif gen_device_type != device.type and gen_device_type == "cuda":
+            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
+
+    # make sure generator list of length 1 is treated like a non-list
+    if isinstance(generator, list) and len(generator) == 1:
+        generator = generator[0]
+
+    # Handle the case where generator is on CPU
+    logits_ = logits.to(gen_device) if generator is not None else logits
+
+    # Multinomial is not implemented for half precision on CPU
+    if logits_.device.type == "cpu" and logits_.dtype != torch.float32:
+        logits_ = logits_.float()
+
+    if isinstance(generator, list):
+        sample = []
+        original_shape = logits.shape[1:-1]
+        for i in range(batch_size):
+            logits_instance = logits_[i]
+            # Flatten extra batch dims so torch.multinomial sees a 2-D tensor
+            needs_reshape = logits_instance.ndim > 2
+            if needs_reshape:
+                logits_instance = logits_instance.reshape(-1, num_cats)
+            sample_instance = torch.multinomial(logits_instance, num_samples, replacement, generator=generator[i])
+            if needs_reshape:
+                sample_instance = sample_instance.view(*original_shape, num_samples)
+            sample.append(sample_instance)
+        sample = torch.stack(sample, dim=0).to(device)
+    else:
+        if logits.ndim > 2:
+            original_shape = logits.shape[:-1]
+            logits_ = logits_.reshape(-1, logits.size(-1))
+        sample = torch.multinomial(logits_, num_samples, replacement, generator=generator).to(device)
+        if logits.ndim > 2:
+            sample = sample.view(*original_shape, num_samples)
+
+    if squeeze_trailing_dim:
+        sample = sample.squeeze(-1)
+
+    return sample
+
+
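And a shape-focused sketch for `multinomial_tensor` (sizes arbitrary): leading batch dimensions are flattened for `torch.multinomial` and restored afterwards, and the trailing sample dimension is squeezed when `num_samples` is 1.

    import torch
    from diffusers.utils.torch_utils import multinomial_tensor

    probs = torch.softmax(torch.randn(2, 16, 100), dim=-1)  # [B, L, V]
    ids = multinomial_tensor(probs, 1)  # one category per position -> [2, 16]
    weights = torch.rand(2, 100)  # unnormalized, nonnegative weights are fine
    picks = multinomial_tensor(weights, 5)  # 5 draws without replacement -> [2, 5]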
 def is_compiled_module(module) -> bool:
     """Check whether the module was compiled with torch.compile()"""
     if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
