Skip to content

Commit 8d7032d

Browse files
committed
re-add ability to condition fine transformer with text; more efficient now with cross-attention key/value caching
1 parent cfe23c6 commit 8d7032d

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,7 @@ def __init__(
10391039
fine_attn_depth = 2,
10401040
fine_attn_dim_head = 32,
10411041
fine_attn_heads = 8,
1042+
fine_cross_attend_text = False,
10421043
pad_id = -1,
10431044
num_sos_tokens = None,
10441045
condition_on_text = False,
@@ -1137,6 +1138,8 @@ def __init__(
11371138

11381139
# decoding the vertices, 2-stage hierarchy
11391140

1141+
self.fine_cross_attend_text = condition_on_text and fine_cross_attend_text
1142+
11401143
self.fine_decoder = Decoder(
11411144
dim = dim_fine,
11421145
depth = fine_attn_depth,
@@ -1145,6 +1148,9 @@ def __init__(
11451148
attn_flash = flash_attn,
11461149
attn_dropout = dropout,
11471150
ff_dropout = dropout,
1151+
cross_attend = self.fine_cross_attend_text,
1152+
cross_attn_dim_context = cross_attn_dim_context,
1153+
cross_attn_num_mem_kv = cross_attn_num_mem_kv,
11481154
**attn_kwargs
11491155
)
11501156

@@ -1512,8 +1518,17 @@ def forward_on_codes(
15121518
if exists(fine_cache):
15131519
for attn_intermediate in fine_cache.attn_intermediates:
15141520
ck, cv = attn_intermediate.cached_kv
1515-
ck, cv = map(lambda t: rearrange(t, '(b nf) ... -> b nf ...', b = batch), (ck, cv))
1516-
ck, cv = map(lambda t: t[:, -1, :, :curr_vertex_pos], (ck, cv))
1521+
ck, cv = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (ck, cv)]
1522+
1523+
# when operating on the cached key / values, treat self attention and cross attention differently
1524+
1525+
layer_type = attn_intermediate.layer_type
1526+
1527+
if layer_type == 'a':
1528+
ck, cv = [t[:, -1, :, :curr_vertex_pos] for t in (ck, cv)]
1529+
elif layer_type == 'c':
1530+
ck, cv = [t[:, -1, ...] for t in (ck, cv)]
1531+
15171532
attn_intermediate.cached_kv = (ck, cv)
15181533

15191534
num_faces = fine_vertex_codes.shape[1]
@@ -1524,9 +1539,25 @@ def forward_on_codes(
15241539
if one_face:
15251540
fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]
15261541

1542+
# handle maybe cross attention conditioning of fine transformer with text
1543+
1544+
fine_attn_context_kwargs = dict()
1545+
1546+
if self.fine_cross_attend_text:
1547+
repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]
1548+
1549+
text_embed = repeat(text_embed, 'b ... -> (b r) ...' , r = repeat_batch)
1550+
text_mask = repeat(text_mask, 'b ... -> (b r) ...', r = repeat_batch)
1551+
1552+
fine_attn_context_kwargs = dict(
1553+
context = text_embed,
1554+
context_mask = text_mask
1555+
)
1556+
15271557
attended_vertex_codes, fine_cache = self.fine_decoder(
15281558
fine_vertex_codes,
15291559
cache = fine_cache,
1560+
**fine_attn_context_kwargs,
15301561
return_hiddens = True
15311562
)
15321563

meshgpt_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.2.9'
1+
__version__ = '1.2.10'

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
'torchtyping',
3737
'tqdm',
3838
'vector-quantize-pytorch>=1.14.22',
39-
'x-transformers>=1.30.4',
39+
'x-transformers>=1.30.6',
4040
],
4141
classifiers=[
4242
'Development Status :: 4 - Beta',

0 commit comments

Comments (0)