@@ -660,53 +660,6 @@ def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos
 
         return self.to_out(outputs)
 
-# global linear attention - only for type 0
-
-class GlobalLinearAttention(nn.Module):
-    def __init__(
-        self,
-        fiber,
-        dim_head = 64,
-        heads = 8,
-        **kwargs
-    ):
-        super().__init__()
-        inner_dim = dim_head * heads
-        self.scale = dim_head ** -0.5
-        self.heads = heads
-
-        self.to_qkv = nn.Linear(fiber[0], inner_dim * 3, bias = False)
-        self.to_out = nn.Linear(inner_dim, fiber[0])
-
-    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
-        h = self.heads
-        device, dtype = get_tensor_device_and_dtype(features)
-
-        x = features['0'] # only working on type 0 features for global linear attention
-        x = rearrange(x, '... () -> ...')
-
-        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
-
-        if exists(mask):
-            mask = rearrange(mask, 'b n -> b () n ()')
-            k = k.masked_fill(~mask, -torch.finfo(k.dtype).max)
-            v = v.masked_fill(~mask, 0.)
-
-        q = q.softmax(dim = -1)
-        k = k.softmax(dim = -2)
-
-        q *= self.scale
-
-        context = einsum('b h n d, b h n e -> b h d e', k, v)
-        attn_out = einsum('b h d e, b h n d -> b h n e', context, q)
-        attn_out = rearrange(attn_out, 'b h n d -> b n (h d)')
-        attn_out = self.to_out(attn_out)
-
-        out = map_values(lambda *args: 0, features)
-        out['0'] = rearrange(attn_out, '... -> ... ()')
-        return out
-
 class AttentionBlockSE3(nn.Module):
     def __init__(
         self,
@@ -1027,7 +980,6 @@ def __init__(
         tie_key_values = False,
         rotary_position = False,
         rotary_rel_dist = False,
-        global_linear_attn_every = 0,
         norm_gated_scale = False,
         use_egnn = False,
         egnn_hidden_dim = 32,
@@ -1153,8 +1105,7 @@ def __init__(
         else:
             layers = nn.ModuleList([])
             for ind in range(depth):
-                use_global_linear_attn = global_linear_attn_every > 0 and (ind % global_linear_attn_every) == 0
-                attention_klass = default_attention_klass if not use_global_linear_attn else GlobalLinearAttention
+                attention_klass = default_attention_klass
 
                 layers.append(nn.ModuleList([
                     AttentionBlockSE3(fiber_hidden, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, rel_dist_num_fourier_features = rel_dist_num_fourier_features, use_null_kv = use_null_kv, splits = splits, global_feats_dim = global_feats_dim, linear_proj_keys = linear_proj_keys, attention_klass = attention_klass, tie_key_values = tie_key_values, norm_gated_scale = norm_gated_scale),
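
For reference, the GlobalLinearAttention class removed above applies the standard linear-attention trick to the type-0 (scalar) features only: queries are softmax-normalized over the feature dimension, keys over the sequence dimension, keys and values are reduced to a single per-head context matrix, and queries then read from that matrix, so the cost grows linearly with sequence length. The sketch below reproduces that computation on plain dense tensors; the function name global_linear_attention, the toy sizes, and the usage lines are illustrative assumptions, not part of the library.

# Standalone sketch of the linear-attention computation performed by the
# removed GlobalLinearAttention class, written against plain dense tensors.
# The helper name and the toy sizes below are assumptions for illustration.

import torch
from torch import nn, einsum
from einops import rearrange

def global_linear_attention(x, to_qkv, to_out, heads, mask = None):
    # x: (batch, n, dim) type-0 features with the trailing singleton dimension removed
    q, k, v = to_qkv(x).chunk(3, dim = -1)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), (q, k, v))

    if mask is not None:
        mask = rearrange(mask, 'b n -> b () n ()')
        k = k.masked_fill(~mask, -torch.finfo(k.dtype).max)
        v = v.masked_fill(~mask, 0.)

    # linear attention: normalize queries over the feature dimension and keys over
    # the sequence dimension, then aggregate keys and values into a (d x e) context
    # matrix, making the cost linear rather than quadratic in sequence length
    q = q.softmax(dim = -1) * (q.shape[-1] ** -0.5)
    k = k.softmax(dim = -2)

    context = einsum('b h n d, b h n e -> b h d e', k, v)
    out = einsum('b h d e, b h n d -> b h n e', context, q)
    out = rearrange(out, 'b h n d -> b n (h d)')
    return to_out(out)

# toy usage with assumed sizes
dim, heads, dim_head, n = 32, 8, 64, 16
to_qkv = nn.Linear(dim, dim_head * heads * 3, bias = False)
to_out = nn.Linear(dim_head * heads, dim)
feats = torch.randn(2, n, dim)
mask = torch.ones(2, n).bool()
print(global_linear_attention(feats, to_qkv, to_out, heads, mask).shape)  # torch.Size([2, 16, 32])

In the removed class, this result is written back into the '0' entry of the feature dictionary (with the singleton dimension restored) while all higher-degree features are zeroed out, which is why it was only ever interleaved with the regular SE(3)-equivariant attention layers rather than used on its own.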