-from typing import Literal, Union, overload
+from typing import Callable, Literal, Union, cast, overload

import torch
from jaxtyping import Float
@@ -32,6 +32,56 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False, local_o
        )
        return decoder_norm

+    def activation_function_factory(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        assert self.cfg.act_fn.lower() in [
+            "relu",
+            "topk",
+            "jumprelu",
+            "batchtopk",
+        ], f"Not implemented activation function {self.cfg.act_fn}"
+        if self.cfg.act_fn.lower() == "jumprelu":
+
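+            # Straight-through estimator (STE) for the JumpReLU gate: the forward pass
+            # emits a hard 0/1 mask over the pre-activations, while the backward pass
+            # routes gradient only to the learned log-threshold, never to the input.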
+            class STEFunction(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input: torch.Tensor, log_jumprelu_threshold: torch.Tensor):
+                    jumprelu_threshold = log_jumprelu_threshold.exp()
+                    jumprelu_threshold = all_reduce_tensor(jumprelu_threshold, aggregate="sum")
+                    ctx.save_for_backward(input, jumprelu_threshold)
+                    return input.gt(jumprelu_threshold).to(input.dtype)
+
+                @staticmethod
+                def backward(ctx, *grad_outputs: torch.Tensor):
+                    assert len(grad_outputs) == 1
+                    grad_output = grad_outputs[0]
+
+                    input, jumprelu_threshold = ctx.saved_tensors
+                    grad_input = torch.zeros_like(input)
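+                    # The hard gate has zero gradient almost everywhere, so no gradient is
+                    # passed to the input (grad_input stays zero). The log-threshold instead
+                    # receives a rectangular surrogate gradient that is nonzero only where the
+                    # input falls within jumprelu_threshold_window of the threshold.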
+                    grad_log_jumprelu_threshold_unscaled = torch.where(
+                        (input - jumprelu_threshold).abs() < self.cfg.jumprelu_threshold_window * 0.5,
+                        -jumprelu_threshold / self.cfg.jumprelu_threshold_window,
+                        0.0,
+                    )
+                    grad_log_jumprelu_threshold = (
+                        grad_log_jumprelu_threshold_unscaled
+                        / torch.where(
+                            ((input - jumprelu_threshold).abs() < self.cfg.jumprelu_threshold_window * 0.5)
+                            * (input != 0.0),
+                            input,
+                            1.0,
+                        )
+                        * grad_output
+                    )
+                    grad_log_jumprelu_threshold = grad_log_jumprelu_threshold.sum(
+                        dim=tuple(range(grad_log_jumprelu_threshold.ndim - 1))
+                    )
+
+                    return grad_input, grad_log_jumprelu_threshold
+
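+            # Bind the module's log_jumprelu_threshold to the autograd Function; cast() is a
+            # typing no-op that only narrows the return type of Function.apply.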
+            return lambda x: cast(torch.Tensor, STEFunction.apply(x, self.log_jumprelu_threshold))
+
+        else:
+            return super().activation_function_factory()
+
    @overload
    def encode(
        self,
@@ -109,14 +159,14 @@ def encode(
        hidden_pre = self.hook_hidden_pre(hidden_pre)

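+        # sparsity_scores: the pre-activations, optionally scaled by each feature's decoder
+        # column norm so that the gating decision reflects the feature's contribution to the output.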
        if self.cfg.sparsity_include_decoder_norm:
-            true_feature_acts = hidden_pre * self._decoder_norm(
+            sparsity_scores = hidden_pre * self._decoder_norm(
                decoder=self.decoder,
                local_only=True,
            )
        else:
-            true_feature_acts = hidden_pre
+            sparsity_scores = hidden_pre

-        activation_mask = self.activation_function(true_feature_acts)
+        activation_mask = self.activation_function(sparsity_scores)
        feature_acts = hidden_pre * activation_mask

        feature_acts = self.hook_feature_acts(feature_acts)
@@ -131,7 +181,9 @@ def compute_loss(
        batch: dict[str, torch.Tensor],
        *,
        use_batch_norm_mse: bool = False,
-        lp: int = 1,
+        sparsity_loss_type: Literal["power", "tanh", None] = None,
+        tanh_stretch_coefficient: float = 4.0,
+        p: int = 1,
        return_aux_data: Literal[True] = True,
        **kwargs,
    ) -> tuple[
@@ -145,7 +197,9 @@ def compute_loss(
        batch: dict[str, torch.Tensor],
        *,
        use_batch_norm_mse: bool = False,
-        lp: int = 1,
+        sparsity_loss_type: Literal["power", "tanh", None] = None,
+        tanh_stretch_coefficient: float = 4.0,
+        p: int = 1,
        return_aux_data: Literal[False],
        **kwargs,
    ) -> Float[torch.Tensor, " batch"]: ...
@@ -162,7 +216,9 @@ def compute_loss(
        ) = None,
        *,
        use_batch_norm_mse: bool = False,
-        lp: int = 1,
+        sparsity_loss_type: Literal["power", "tanh", None] = None,
+        tanh_stretch_coefficient: float = 4.0,
+        p: int = 1,
        return_aux_data: bool = True,
        **kwargs,
    ) -> Union[
@@ -194,25 +250,31 @@ def compute_loss(
            .sqrt()
        )

-        l_rec = l_rec.mean()
-        l_rec = all_reduce_tensor(l_rec, aggregate="mean")
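+        # Sum l_rec over its trailing dimension and average over the batch; the
+        # distributed mean-reduction is now applied once to the total loss below.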
+        l_rec = l_rec.sum(dim=-1).mean()

        loss = l_rec
        loss_dict = {
            "l_rec": l_rec,
        }

-        # l_l1: (batch,)
-        feature_acts = feature_acts * self._decoder_norm(
-            decoder=self.decoder,
-            local_only=True,
-        )
-
-        if "topk" not in self.cfg.act_fn:
-            l_lp = torch.norm(feature_acts, p=lp, dim=-1)
-            loss_dict["l_lp"] = l_lp
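+        # Sparsity penalty on decoder-norm-weighted activations: "power" takes their L^p norm
+        # (p=1 gives an L1-style penalty), while "tanh" applies a saturating soft count.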
+        if sparsity_loss_type == "power":
+            l_s = torch.norm(feature_acts * self._decoder_norm(decoder=self.decoder), p=p, dim=-1)
+            loss_dict["l_s"] = self.current_l1_coefficient * l_s.mean()
            assert self.current_l1_coefficient is not None
-            loss = loss + self.current_l1_coefficient * l_lp.mean()
+            loss = loss + self.current_l1_coefficient * l_s.mean()
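+        # tanh saturates for large activations, so this term roughly counts active features;
+        # tanh_stretch_coefficient sets how quickly the penalty saturates.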
+        elif sparsity_loss_type == "tanh":
+            l_s = torch.tanh(tanh_stretch_coefficient * feature_acts * self._decoder_norm(decoder=self.decoder)).sum(
+                dim=-1
+            )
+            loss_dict["l_s"] = self.current_l1_coefficient * l_s.mean()
+            assert self.current_l1_coefficient is not None
+            loss = loss + self.current_l1_coefficient * l_s.mean()
+        elif sparsity_loss_type is None:
+            pass
+        else:
+            raise ValueError(f"sparsity_loss_type {sparsity_loss_type} not supported.")
+
+        loss = all_reduce_tensor(loss, aggregate="mean")

        if return_aux_data:
            aux_data = {
@@ -229,7 +291,8 @@ def compute_loss(

    @torch.no_grad()
    def log_statistics(self):
-        return {}
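+        # Surface the stored dataset-average activation norms under the "info/" prefix.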
+        assert self.dataset_average_activation_norm is not None
+        return {f"info/{k}": v for k, v in self.dataset_average_activation_norm.items()}

    def initialize_with_same_weight_across_layers(self):
        self.encoder.weight.data = get_tensor_from_specific_rank(self.encoder.weight.data.clone(), src=0)