Skip to content

Commit 6c5de95

Browse files
authored
Merge pull request #18 from argmaxinc/long_context_fix
Truncating long prompts
2 parents cf57983 + 389ff0d commit 6c5de95

File tree

3 files changed

+26
-6
lines changed

python/src/diffusionkit/mlx/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@
4040
"FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
4141
}
4242

43+
T5_MAX_LENGTH = {
44+
"stable-diffusion-3-medium": 512,
45+
"FLUX.1-schnell": 256,
46+
}
47+
4348

4449
class DiffusionKitInferenceContext(AppleSiliconContextMixin, InferenceContextSpec):
4550
def code_spec(self):
@@ -138,7 +143,9 @@ def set_up_t5(self):
138143
low_memory_mode=self.low_memory_mode,
139144
)
140145
if not hasattr(self, "t5_tokenizer") or self.t5_tokenizer is None:
141-
self.t5_tokenizer = load_t5_tokenizer()
146+
self.t5_tokenizer = load_t5_tokenizer(
147+
max_context_length=T5_MAX_LENGTH[self.model_version]
148+
)
142149
self.use_t5 = True
143150

144151
def unload_t5(self):

python/src/diffusionkit/mlx/model_io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,6 @@ def load_tokenizer(
871871
return Tokenizer(bpe_ranks, vocab, pad_with_eos)
872872

873873

874-
def load_t5_tokenizer():
874+
def load_t5_tokenizer(max_context_length: int = 256):
875875
config = T5Config.from_pretrained("google/t5-v1_1-xxl")
876-
return T5Tokenizer(config)
876+
return T5Tokenizer(config, max_context_length)

python/src/diffusionkit/mlx/tokenizer.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
import mlx.core as mx
66
import numpy as np
77
import regex
8+
from argmaxtools.utils import get_logger
89
from transformers import AutoTokenizer, T5Config
910

11+
logger = get_logger(__name__)
12+
1013

1114
class Tokenizer:
1215
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
@@ -101,6 +104,14 @@ def tokenize(self, text, prepend_bos=True, append_eos=True):
101104

102105
# Map to token ids and return
103106
tokens = [self.vocab[t] for t in bpe_tokens]
107+
108+
# Truncate
109+
max_length = self.max_length - int(prepend_bos) - int(append_eos)
110+
if len(tokens) > max_length:
111+
tokens = tokens[:max_length]
112+
logger.warning(
113+
f"Length of tokens exceeds {self.max_length}. Truncating to {self.max_length}."
114+
)
104115
if prepend_bos:
105116
tokens = [self.bos_token] + tokens
106117
if append_eos:
@@ -110,16 +121,16 @@ def tokenize(self, text, prepend_bos=True, append_eos=True):
110121

111122

112123
class T5Tokenizer:
113-
def __init__(self, config: T5Config):
124+
def __init__(self, config: T5Config, max_context_length: int):
125+
self.max_length = max_context_length
114126
self._decoder_start_id = config.decoder_start_token_id
115127
self._tokenizer = AutoTokenizer.from_pretrained(
116128
"google/t5-v1_1-xxl",
117129
legacy=False,
118-
model_max_length=getattr(config, "n_positions", 512),
130+
model_max_length=self.max_length,
119131
)
120132

121133
self.pad_to_max_length = True
122-
self.max_length = 77
123134
self.pad_with_eos = False
124135

125136
@property
@@ -136,6 +147,8 @@ def encode(self, s: str) -> mx.array:
136147
s,
137148
return_tensors="np",
138149
return_attention_mask=False,
150+
max_length=self.max_length,
151+
truncation=True,
139152
)["input_ids"]
140153
)
141154

0 commit comments

Comments (0)