Commit dac71b1

Add initial implementation of a text processor analogous to VaeImageProcessor
1 parent a602b06 commit dac71b1

File tree

2 files changed: +167 -6 lines changed

src/diffusers/pipelines/dream/pipeline_dream.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -5,6 +5,7 @@
 
 from ...models import DreamTransformer1DModel
 from ...schedulers import DreamMaskedDiffusionScheduler
+from ...text_processor import TokenizerTextProcessor
 from ...utils import is_torch_xla_available, logging
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import DreamTextPipelineOutput
@@ -47,6 +48,7 @@ def __init__(
             scheduler=scheduler,
         )
 
+        self.text_processor = TokenizerTextProcessor()
         # 131072 in original code
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 512
@@ -185,7 +187,7 @@ def __call__(
         max_sequence_length: int = 512,
         apply_chat_template: bool = False,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        output_type: str = "pil",  # TODO: replace with options appropriate for text
+        output_type: str = "str",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -222,8 +224,8 @@ def __call__(
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            output_type (`str`, *optional*, defaults to `"str"`):
+                The output format of the generated text. Choose between `str`, `np`, `pt`, or `latent`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
@@ -364,9 +366,7 @@ def __call__(
         if output_type == "latent":
            texts = latents
         else:
-            # TODO: should there be a text_processor class analogous to e.g. VaeImageProcessor???
-            texts = self.tokenizer.batch_decode(latents)
-            # TODO: if prompt or other conditioning is supplied, remove prompts from generated texts???
+            texts = self.text_processor.postprocess(self.tokenizer, latents, output_type=output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
```
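With this change, decoding is delegated to the processor instead of calling `tokenizer.batch_decode` inline. A minimal sketch of how the new `output_type` values would behave at call time; the pipeline class name, checkpoint id, and the `texts` output field are illustrative assumptions, not confirmed by this commit:

```python
# Hypothetical usage sketch: `DreamPipeline`, the checkpoint id, and the
# `texts` output attribute are placeholders, not confirmed by this commit.
from diffusers import DreamPipeline

pipe = DreamPipeline.from_pretrained("dream-org/dream-checkpoint")

# Default "str": TokenizerTextProcessor.postprocess decodes token ids to strings.
result = pipe("Write a haiku about autumn.", output_type="str")
print(result.texts[0])

# "pt" and "np" return raw token ids from postprocess; "latent" returns the
# pipeline's latents without entering postprocess at all.
ids = pipe("Write a haiku about autumn.", output_type="pt")
```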

src/diffusers/text_processor.py

Lines changed: 161 additions & 0 deletions

The entire file is new:

```python
import warnings
from typing import Dict, List, Optional, Union

import numpy as np
import torch
from transformers import PreTrainedTokenizerBase

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME


class TokenizerTextProcessor(ConfigMixin):
    """
    Text processor for text models using a `transformers`-style `PreTrainedTokenizerBase`.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(self, apply_chat_template: bool = False):
        super().__init__()

    @staticmethod
    def numpy_to_pt(text_ids: np.ndarray) -> torch.Tensor:
        # text_ids shape: [batch_size, seq_len]
        text_ids = torch.from_numpy(text_ids)
        return text_ids

    @staticmethod
    def pt_to_numpy(text_ids: torch.Tensor) -> np.ndarray:
        # text_ids shape: [batch_size, seq_len]
        text_ids = text_ids.cpu().numpy()
        return text_ids

    @staticmethod
    def is_chat_conversation(
        text: Union[str, List[str], List[Dict[str, str]], List[List[Dict[str, str]]]]
    ) -> bool:
        is_chat_conversation = False
        if isinstance(text, list) and len(text) > 0:
            if isinstance(text[0], dict):
                is_chat_conversation = True  # List[Dict[str, str]]
            elif isinstance(text[0], list) and isinstance(text[0][0], dict):
                is_chat_conversation = True  # List[List[Dict[str, str]]]
            elif not isinstance(text[0], str):
                raise ValueError(
                    "`text` should either be a list of str or a list of Dict[str, str] representing chat history,"
                    f" but is a list of type {type(text[0])}"
                )
        return is_chat_conversation

    def preprocess(
        self,
        tokenizer: PreTrainedTokenizerBase,
        text: Union[str, List[str], List[Dict[str, str]], List[List[Dict[str, str]]]],
        apply_chat_template: Optional[bool] = None,
        **kwargs,
    ):
        """
        Converts the supplied text to token ids using the tokenizer. This supports normal tokenization via the
        tokenizer's `__call__` method and chat tokenization via the `apply_chat_template` method.

        Args:
            tokenizer (`transformers.PreTrainedTokenizerBase`):
                A `transformers`-style fast or slow tokenizer.
            text (`str` or `List[str]` or `List[Dict[str, str]]` or `List[List[Dict[str, str]]]`):
                The text to be tokenized. If tokenizing normally, should be a `str` or `List[str]`; if using chat
                tokenization, should be `List[Dict[str, str]]` or `List[List[Dict[str, str]]]`.
            apply_chat_template (`bool`, *optional*, defaults to `None`):
                Whether to process the `text` as chat input using `apply_chat_template`. If not set, this defaults
                to the `apply_chat_template` value set in the config.
            kwargs (additional keyword arguments, *optional*):
                Keyword arguments as appropriate for `apply_chat_template` or `__call__`, depending on whether chat
                or normal tokenization is used; these will be passed to the respective methods above. Note that
                `return_tensors` is explicitly set to `"pt"` when these methods are called.
        """
        if apply_chat_template is None:
            apply_chat_template = self.config.apply_chat_template

        if isinstance(text, str):
            text = [text]

        is_chat_conversation = self.is_chat_conversation(text)
        if not is_chat_conversation and apply_chat_template:
            warnings.warn(
                "The supplied text is not chat input but apply_chat_template is True. The input will be converted"
                " into a simple chat input format.",
                UserWarning,
            )
            # Wrap each string as its own single-turn conversation so that a batch of
            # strings becomes a batch of conversations.
            text = [[{"role": "user", "content": message}] for message in text]

        if apply_chat_template:
            text_inputs = tokenizer.apply_chat_template(text, return_tensors="pt", return_dict=False, **kwargs)
        elif is_chat_conversation:
            warnings.warn(
                "The supplied `text` is in the form of a chat conversation but apply_chat_template is False. The"
                " input will be treated as chat input (i.e. processed with `apply_chat_template`).",
                UserWarning,
            )
            text_inputs = tokenizer.apply_chat_template(text, return_tensors="pt", return_dict=False, **kwargs)
        else:
            # Process normally using the tokenizer's __call__ method
            text_inputs = tokenizer(text, return_tensors="pt", **kwargs)

        return text_inputs

    def postprocess(
        self,
        tokenizer: PreTrainedTokenizerBase,
        text_ids: torch.Tensor,
        prompt_ids: Optional[torch.Tensor] = None,
        output_type: str = "str",
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs,
    ) -> Union[List[str], np.ndarray, torch.Tensor]:
        """
        Decodes the generated text_ids using the tokenizer.

        Args:
            tokenizer (`transformers.PreTrainedTokenizerBase`):
                A `transformers`-style fast or slow tokenizer.
            text_ids (`torch.Tensor`):
                Generated text token ids from the model.
            prompt_ids (`torch.Tensor`, *optional*):
                Optional prompt token ids; if supplied, these will be used to remove the prompt from the generated
                samples.
            output_type (`str`, defaults to `"str"`):
                The output type of the text, can be one of `str`, `np`, `pt`, or `latent`.
            skip_special_tokens (`bool`, defaults to `False`):
                Whether to remove special tokens during decoding.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `None`):
                Whether to clean up tokenization spaces.
            kwargs (additional keyword arguments, *optional*):
                Additional keyword arguments which will be passed to the tokenizer's underlying `decode` method.

        Returns:
            `List[str]`, `np.ndarray`, or `torch.Tensor`:
                A list of generated texts as strings if `output_type` is `"str"`, otherwise the token ids in the
                requested format.
        """
        # text_ids shape: [batch_size, gen_seq_len]
        # prompt_ids shape: [batch_size, input_seq_len]
        # Assume input_seq_len <= gen_seq_len
        if output_type == "latent" or output_type == "pt":
            return text_ids

        text_ids = self.pt_to_numpy(text_ids)

        if output_type == "np":
            return text_ids

        if prompt_ids is not None:
            # Remove the prompt prefix from each generated sample before decoding.
            texts = [
                tokenizer.decode(
                    sample[len(prompt) :],
                    skip_special_tokens=skip_special_tokens,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    **kwargs,
                )
                for sample, prompt in zip(text_ids, prompt_ids)
            ]
        else:
            texts = tokenizer.batch_decode(
                text_ids,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                **kwargs,
            )

        return texts
```
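For reference, a minimal round trip through the new processor might look like the following sketch; the GPT-2 tokenizer and the fabricated "generated" ids are purely illustrative and not part of this commit:

```python
# Hypothetical round trip; any transformers tokenizer works here.
import torch
from transformers import AutoTokenizer

from diffusers.text_processor import TokenizerTextProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
processor = TokenizerTextProcessor()

# preprocess: text -> token ids (always returned as torch tensors).
inputs = processor.preprocess(tokenizer, "Hello world")
prompt_ids = inputs.input_ids  # shape: [1, input_seq_len]

# Stand-in for model output: the prompt followed by newly generated tokens.
new_tokens = torch.tensor([[11, 703, 389, 345]])  # arbitrary GPT-2 token ids
generated = torch.cat([prompt_ids, new_tokens], dim=1)

# postprocess: token ids -> text, stripping the prompt from each sample.
texts = processor.postprocess(tokenizer, generated, prompt_ids=prompt_ids)
print(texts[0])  # only the continuation is decoded
```

Note that the tokenizer is passed per call rather than stored on the processor, so a single instance can be shared across pipelines; only the `apply_chat_template` default lives in the config.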
