@@ -1039,6 +1039,7 @@ def __init__(
         fine_attn_depth = 2,
         fine_attn_dim_head = 32,
         fine_attn_heads = 8,
+        fine_cross_attend_text = False,
         pad_id = -1,
         num_sos_tokens = None,
         condition_on_text = False,
@@ -1137,6 +1138,8 @@ def __init__(

         # decoding the vertices, 2-stage hierarchy

+        self.fine_cross_attend_text = condition_on_text and fine_cross_attend_text
+
         self.fine_decoder = Decoder(
             dim = dim_fine,
             depth = fine_attn_depth,
@@ -1145,6 +1148,9 @@ def __init__(
             attn_flash = flash_attn,
             attn_dropout = dropout,
             ff_dropout = dropout,
+            cross_attend = self.fine_cross_attend_text,
+            cross_attn_dim_context = cross_attn_dim_context,
+            cross_attn_num_mem_kv = cross_attn_num_mem_kv,
             **attn_kwargs
         )

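With the constructor wired up as above, the fine decoder can cross attend to the text embeddings whenever text conditioning is enabled. A minimal usage sketch follows; it assumes the MeshAutoencoder / MeshTransformer interfaces from this repo's README, and the autoencoder settings and dim value are illustrative placeholders rather than part of this commit:

from meshgpt_pytorch import MeshAutoencoder, MeshTransformer

autoencoder = MeshAutoencoder(
    num_discrete_coors = 128    # placeholder autoencoder config
)

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    condition_on_text = True,        # must be on, since the new flag is gated on text conditioning
    fine_cross_attend_text = True    # new in this commit: fine decoder cross attends to text embeddings
)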
@@ -1512,8 +1518,17 @@ def forward_on_codes(
         if exists(fine_cache):
             for attn_intermediate in fine_cache.attn_intermediates:
                 ck, cv = attn_intermediate.cached_kv
-                ck, cv = map(lambda t: rearrange(t, '(b nf) ... -> b nf ...', b = batch), (ck, cv))
-                ck, cv = map(lambda t: t[:, -1, :, :curr_vertex_pos], (ck, cv))
+                ck, cv = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (ck, cv)]
+
+                # when operating on the cached key / values, treat self attention and cross attention differently
+
+                layer_type = attn_intermediate.layer_type
+
+                if layer_type == 'a':
+                    ck, cv = [t[:, -1, :, :curr_vertex_pos] for t in (ck, cv)]
+                elif layer_type == 'c':
+                    ck, cv = [t[:, -1, ...] for t in (ck, cv)]
+
                 attn_intermediate.cached_kv = (ck, cv)

         num_faces = fine_vertex_codes.shape[1]
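For context on the cache handling above: the layer intermediates carry a layer_type marker, 'a' for self attention and 'c' for cross attention. Self attention keys/values grow with the decoded vertex tokens, so after regrouping the folded face batch they are sliced to the current vertex position; cross attention keys/values come from the fixed-length text context, so only the last face's cache is kept and no positional slice is applied. A standalone shape sketch of that difference (all tensor sizes below are made-up illustrations, not taken from the repo):

import torch
from einops import rearrange

batch, num_faces, heads, dim_head = 2, 4, 8, 64
curr_vertex_pos, text_len = 3, 16

# cached self attention keys: one row per (batch, face), over the vertex tokens decoded so far
self_k = torch.randn(batch * num_faces, heads, 6, dim_head)
# cached cross attention keys: one row per (batch, face), over the full text context
cross_k = torch.randn(batch * num_faces, heads, text_len, dim_head)

k = rearrange(self_k, '(b nf) ... -> b nf ...', b = batch)
k = k[:, -1, :, :curr_vertex_pos]   # layer_type == 'a': keep only tokens up to the current vertex
assert k.shape == (batch, heads, curr_vertex_pos, dim_head)

k = rearrange(cross_k, '(b nf) ... -> b nf ...', b = batch)
k = k[:, -1, ...]                   # layer_type == 'c': keep the full text length
assert k.shape == (batch, heads, text_len, dim_head)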
@@ -1524,9 +1539,25 @@ def forward_on_codes(
         if one_face:
             fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]

+        # handle maybe cross attention conditioning of fine transformer with text
+
+        fine_attn_context_kwargs = dict()
+
+        if self.fine_cross_attend_text:
+            repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]
+
+            text_embed = repeat(text_embed, 'b ... -> (b r) ...', r = repeat_batch)
+            text_mask = repeat(text_mask, 'b ... -> (b r) ...', r = repeat_batch)
+
+            fine_attn_context_kwargs = dict(
+                context = text_embed,
+                context_mask = text_mask
+            )
+
         attended_vertex_codes, fine_cache = self.fine_decoder(
             fine_vertex_codes,
             cache = fine_cache,
+            **fine_attn_context_kwargs,
             return_hiddens = True
         )

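The repeat calls in the last hunk exist because the fine decoder runs with faces folded into the batch dimension, so the single text embedding per sample has to be broadcast to every face of that sample before it can serve as cross attention context. A small einops illustration of that broadcasting (the shapes are placeholders, not values from the repo):

import torch
from einops import repeat

batch, num_faces, text_len, dim = 2, 5, 16, 512
text_embed = torch.randn(batch, text_len, dim)
text_mask = torch.ones(batch, text_len, dtype = torch.bool)

fine_batch = batch * num_faces             # plays the role of fine_vertex_codes.shape[0]
repeat_batch = fine_batch // text_embed.shape[0]

text_embed = repeat(text_embed, 'b ... -> (b r) ...', r = repeat_batch)
text_mask = repeat(text_mask, 'b ... -> (b r) ...', r = repeat_batch)

assert text_embed.shape == (fine_batch, text_len, dim)
assert text_mask.shape == (fine_batch, text_len)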