Skip to content

Commit 4c833a2

Browse files
committed
Add option to disable caching
1 parent b76b78e commit 4c833a2

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

dalle_pytorch/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
257257
nn.Dropout(dropout)
258258
)
259259

260-
def forward(self, x, mask = None, rotary_pos_emb = None):
260+
def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
261261
b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
262262
softmax = torch.softmax if not self.stable else stable_softmax
263263

dalle_pytorch/dalle_pytorch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,8 @@ def generate_images(
486486
filter_thres = 0.5,
487487
temperature = 1.,
488488
img = None,
489-
num_init_img_tokens = None
489+
num_init_img_tokens = None,
490+
use_cache = False,
490491
):
491492
vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
492493
total_len = text_seq_len + image_seq_len
@@ -505,7 +506,7 @@ def generate_images(
505506
indices = indices[:, :num_img_tokens]
506507
out = torch.cat((out, indices), dim = -1)
507508

508-
cache = {}
509+
cache = {} if use_cache else None
509510
for cur_len in range(out.shape[1], total_len):
510511
is_image = cur_len >= text_seq_len
511512

0 commit comments

Comments
 (0)