 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
-
+import numpy as np
 import torch
 from torch.utils.checkpoint import checkpoint
 
@@ -45,6 +45,7 @@ def __init__(
         super(StreamEmbedTransformer, self).__init__()
 
         self.num_tokens = num_tokens
+        self.token_size = token_size
         self.num_channels = num_channels
         self.dim_in = token_size if mode == "channels" else num_channels
         self.dim_embed = dim_embed
@@ -56,8 +57,6 @@ def __init__(
 
         norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm
 
-        self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)
-
         self.layers = torch.nn.ModuleList()
         for _ in range(self.num_blocks):
             self.layers.append(
@@ -80,6 +79,8 @@ def __init__(
             )
 
         if mode == "channels":
+            self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)
+
             if self.unembed_mode == "full":
                 self.ln_final = norm(num_channels * self.dim_embed)
                 self.unembed = torch.nn.Linear(
@@ -94,6 +95,11 @@ def __init__(
                 dim_out = (self.num_tokens * self.dim_out - embed_size_centroids) // num_channels
                 self.unembed = torch.nn.ModuleList(
                     [torch.nn.Linear(dim_embed, dim_out) for _ in range(num_channels)]
+                    # [
+                    #     torch.nn.Sequential(torch.nn.Linear(dim_embed, max(dim_embed // 2, 4 * dim_out)),
+                    #                         torch.nn.GELU(),
+                    #                         torch.nn.Linear(max(dim_embed // 2, 4 * dim_out), dim_out)) for _ in range(num_channels)
+                    # ]
                 )
                 self.ln_final = torch.nn.ModuleList([norm(dim_embed) for _ in range(num_channels)])
@@ -103,9 +109,12 @@ def __init__(
             self.forward = self.forward_channels
 
         elif mode == "columns":
+            assert embed_size_centroids == 0
+            self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)
+
             assert self.unembed_mode == "block"  # only supported mode at the moment
             # padding needed if the unembedded columns cannot be concatenated to dim_out (e.g. GPSRO)
-            self.pad = (self.dim_out - embed_size_centroids) % token_size
+            self.pad = self.dim_out % token_size
             self.out_pad = torch.nn.Parameter(torch.zeros(self.pad))
             self.unembed = torch.nn.Linear(
                 self.dim_embed,
@@ -114,6 +123,13 @@ def __init__(
             self.ln_final = norm(dim_out)
             self.forward = self.forward_columns
 
+            # TODO: factorization when sqrt is not int
+            dim1 = int(np.sqrt(dim_out))
+            assert dim1 * dim1 == dim_out
+            self.unembed1 = torch.nn.Linear(self.dim_embed, dim1)
+            self.unembed_nonlin = torch.nn.GELU()
+            self.unembed2 = torch.nn.Linear(self.token_size, dim1)
+
         else:
             assert False
 
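The added `unembed1`/`unembed2` pair factorizes the column unembedding: instead of a single `dim_embed → dim_out` projection, the output is assembled from a projection over the embedding dimension followed by one over the token dimension, which is why `dim_out` must currently be a perfect square (hence the assert and the TODO). A minimal standalone sketch of the resulting shape flow, with illustrative sizes (the concrete values below are assumptions, not taken from the repository):

```python
import numpy as np
import torch

batch, token_size, dim_embed, dim_out = 2, 12, 256, 1024  # dim_out = 32 * 32, a perfect square
dim1 = int(np.sqrt(dim_out))
assert dim1 * dim1 == dim_out

unembed1 = torch.nn.Linear(dim_embed, dim1)   # acts on the embedding dimension
unembed_nonlin = torch.nn.GELU()
unembed2 = torch.nn.Linear(token_size, dim1)  # acts on the token dimension after the transpose

x = torch.randn(batch, token_size, dim_embed)
out = unembed_nonlin(unembed1(x))             # (batch, token_size, dim1)
out = unembed2(out.transpose(-2, -1))         # (batch, dim1, dim1)
out = out.flatten(-2, -1).unsqueeze(1)        # (batch, 1, dim1 * dim1) == (batch, 1, dim_out)
assert out.shape == (batch, 1, dim_out)
```

The factorized form needs roughly `(dim_embed + token_size) * sqrt(dim_out)` weights instead of `dim_embed * dim_out`, at the cost of restricting `dim_out` to perfect squares until the TODO is resolved.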
@@ -135,7 +151,7 @@ def forward_channels(self, x_in, centroids):
         elif self.unembed_mode == "block":
             out = [
                 checkpoint(ue, ln(x[:, i]), use_reentrant=False)
-                for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=False))
+                for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
             ]
             out = torch.stack(out, dim=1).flatten(-2, -1)
         else:
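Changing `zip(..., strict=False)` to `strict=True` (Python 3.10+) makes a length mismatch between `self.unembed` and `self.ln_final` raise immediately instead of silently truncating to the shorter list. A toy illustration, unrelated to the repository's actual module lists:

```python
# Toy example: strict=True turns silent truncation into an explicit error.
unembed = ["ue0", "ue1", "ue2"]
ln_final = ["ln0", "ln1"]  # one entry short

print(list(zip(unembed, ln_final, strict=False)))  # [('ue0', 'ln0'), ('ue1', 'ln1')] -- 'ue2' is dropped

try:
    list(zip(unembed, ln_final, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```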
@@ -153,27 +169,22 @@ def forward_channels(self, x_in, centroids):
 
         return out
 
-    # @torch.compile(dynamic=True)
     def forward_columns(self, x_in, centroids):
         # embed provided input data
         x = positional_encoding_harmonic(checkpoint(self.embed, x_in, use_reentrant=False))
 
         for layer in self.layers:
             x = checkpoint(layer, x, use_reentrant=False)
 
-        # append centroids
-        # unembed and reshape
-        out = checkpoint(self.unembed, x, use_reentrant=False)
-        out = out.flatten(-2, -1).reshape(x.shape[0], self.num_tokens, -1)
-        # TODO: unsqueeze will not work with num_tokens > 1
-        out = torch.cat([out, self.embed_centroids(centroids).unsqueeze(1)], -1)
-        # pad to uniform dim_out (that has to be uniform across streams)
-        if self.pad > 0:
-            out = torch.cat((out, self.out_pad.repeat((x.shape[0], self.num_tokens, 1))), -1)
-        # also encode centroids with overlayed positional encoding
+        out = checkpoint(self.unembed1, x, use_reentrant=False)
+        out = self.unembed_nonlin(out)
+        out = checkpoint(self.unembed2, out.transpose(-2, -1), use_reentrant=False)
+        out = out.flatten(-2, -1).unsqueeze(1)
+
+        # final normalize and dropout
         out = self.dropout_final(self.ln_final(out))
 
-        return out
+        return out.to(torch.float16)
 
 
 class StreamEmbedLinear(torch.nn.Module):
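Both forward paths wrap each sub-module call in `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)`, trading extra recomputation in the backward pass for lower activation memory. A self-contained sketch of that pattern with a stand-in block (the layer and sizes are illustrative, not the repository's):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Stand-in for one transformer block; sizes are arbitrary.
layer = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())

x = torch.randn(8, 16, 64, requires_grad=True)
# Intermediate activations inside `layer` are recomputed during backward rather than stored.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([8, 16, 64])
```

Note also that `forward_columns` now returns `out.to(torch.float16)`, so downstream consumers receive half-precision tensors regardless of the module's compute dtype.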