Commit fe9b20f

remove global linear attention feature
1 parent 83f64c4 commit fe9b20f

3 files changed: +2 -74 lines changed

se3_transformer_pytorch/se3_transformer_pytorch.py

Lines changed: 1 addition & 50 deletions
@@ -660,53 +660,6 @@ def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos

         return self.to_out(outputs)

-# global linear attention - only for type 0
-
-class GlobalLinearAttention(nn.Module):
-    def __init__(
-        self,
-        fiber,
-        dim_head = 64,
-        heads = 8,
-        **kwargs
-    ):
-        super().__init__()
-        inner_dim = dim_head * heads
-        self.scale = dim_head ** -0.5
-        self.heads = heads
-
-        self.to_qkv = nn.Linear(fiber[0], inner_dim * 3, bias = False)
-        self.to_out = nn.Linear(inner_dim, fiber[0])
-
-    def forward(self, features, edge_info, rel_dist, basis, global_feats = None, pos_emb = None, mask = None):
-        h = self.heads
-        device, dtype = get_tensor_device_and_dtype(features)
-
-        x = features['0'] # only working on type 0 features for global linear attention
-        x = rearrange(x, '... () -> ...')
-
-        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
-
-        if exists(mask):
-            mask = rearrange(mask, 'b n -> b () n ()')
-            k = k.masked_fill(~mask, -torch.finfo(k.dtype).max)
-            v = v.masked_fill(~mask, 0.)
-
-        q = q.softmax(dim = -1)
-        k = k.softmax(dim = -2)
-
-        q *= self.scale
-
-        context = einsum('b h n d, b h n e -> b h d e', k, v)
-        attn_out = einsum('b h d e, b h n d -> b h n e', context, q)
-        attn_out = rearrange(attn_out, 'b h n d -> b n (h d)')
-        attn_out = self.to_out(attn_out)
-
-        out = map_values(lambda *args: 0, features)
-        out['0'] = rearrange(attn_out, '... -> ... ()')
-        return out
-
 class AttentionBlockSE3(nn.Module):
     def __init__(
         self,
@@ -1027,7 +980,6 @@ def __init__(
         tie_key_values = False,
         rotary_position = False,
         rotary_rel_dist = False,
-        global_linear_attn_every = 0,
         norm_gated_scale = False,
         use_egnn = False,
         egnn_hidden_dim = 32,
@@ -1153,8 +1105,7 @@ def __init__(
         else:
             layers = nn.ModuleList([])
             for ind in range(depth):
-                use_global_linear_attn = global_linear_attn_every > 0 and (ind % global_linear_attn_every) == 0
-                attention_klass = default_attention_klass if not use_global_linear_attn else GlobalLinearAttention
+                attention_klass = default_attention_klass

                 layers.append(nn.ModuleList([
                     AttentionBlockSE3(fiber_hidden, heads = heads, dim_head = dim_head, attend_self = attend_self, edge_dim = edge_dim, fourier_encode_dist = fourier_encode_dist, rel_dist_num_fourier_features = rel_dist_num_fourier_features, use_null_kv = use_null_kv, splits = splits, global_feats_dim = global_feats_dim, linear_proj_keys = linear_proj_keys, attention_klass = attention_klass, tie_key_values = tie_key_values, norm_gated_scale = norm_gated_scale),
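
For reference, the deleted class was the linear-attention trick (softmax over the feature axis for queries and over the sequence axis for keys, so cost grows linearly with sequence length) applied to the type-0 features only. Below is a minimal, self-contained sketch of the same arithmetic; the class name GlobalLinearAttentionSketch and the plain (batch, seq, dim) input are assumptions for illustration, whereas the removed module read features['0'] out of the fiber dict and wrote its result back into it.

import torch
from torch import nn, einsum
from einops import rearrange

class GlobalLinearAttentionSketch(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, mask = None):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        if mask is not None:
            mask = rearrange(mask, 'b n -> b () n ()')
            k = k.masked_fill(~mask, -torch.finfo(k.dtype).max)  # padded keys vanish after the softmax
            v = v.masked_fill(~mask, 0.)

        # softmax over the feature dim for queries, over the sequence dim for keys
        q = q.softmax(dim = -1) * self.scale
        k = k.softmax(dim = -2)

        context = einsum('b h n d, b h n e -> b h d e', k, v)    # (d x e) summary of the whole sequence
        out = einsum('b h d e, b h n d -> b h n e', context, q)  # broadcast the summary back to every position
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# usage
attn = GlobalLinearAttentionSketch(dim = 64)
x = torch.randn(1, 32, 64)
mask = torch.ones(1, 32).bool()
out = attn(x, mask = mask)  # (1, 32, 64)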

setup.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
   name = 'se3-transformer-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.8.11',
+  version = '0.8.12',
   license='MIT',
   description = 'SE3 Transformer - Pytorch',
   author = 'Phil Wang'

tests/test_equivariance.py

Lines changed: 0 additions & 23 deletions
@@ -184,29 +184,6 @@ def test_equivariance_with_egnn_backbone():
     diff = (out1 - out2).max()
     assert diff < 1e-4, 'is not equivariant'

-def test_equivariance_with_global_linear_attn():
-    model = SE3Transformer(
-        dim = 64,
-        depth = 4,
-        attend_self = True,
-        num_neighbors = 4,
-        num_degrees = 2,
-        output_degrees = 2,
-        fourier_encode_dist = True,
-        global_linear_attn_every = 2
-    )
-
-    feats = torch.randn(1, 32, 64)
-    coors = torch.randn(1, 32, 3)
-    mask = torch.ones(1, 32).bool()
-
-    R = rot(15, 0, 45)
-    out1 = model(feats, coors @ R, mask, return_type = 1)
-    out2 = model(feats, coors, mask, return_type = 1) @ R
-
-    diff = (out1 - out2).max()
-    assert diff < 1e-4, 'is not equivariant'
-
 def test_rotary():
     model = SE3Transformer(
         dim = 64,
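
For downstream code: as of this commit the global_linear_attn_every keyword no longer exists on SE3Transformer, so configurations that used it should simply drop the flag. A minimal sketch of a post-0.8.12 construction, reusing the argument values from the deleted test; the top-level import path is assumed to be the package's usual export.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 4,
    attend_self = True,
    num_neighbors = 4,
    num_degrees = 2,
    output_degrees = 2,
    fourier_encode_dist = True
    # global_linear_attn_every = 2  <- removed in this commit, no longer a valid argument
)

feats = torch.randn(1, 32, 64)
coors = torch.randn(1, 32, 3)
mask = torch.ones(1, 32).bool()

# type-1 (vector) output; its last dim is 3 and rotates with the coordinates
out = model(feats, coors, mask, return_type = 1)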
