@@ -137,14 +137,16 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
137
137
"sparsity/below_1e-5" : (feature_sparsity < 1e-5 ).sum ().item (),
138
138
"sparsity/below_1e-6" : (feature_sparsity < 1e-6 ).sum ().item (),
139
139
}
140
- if sae .cfg .sae_type == 'crosscoder' :
141
- overall_act_freq_scores = all_reduce_tensor (feature_sparsity , aggregate = 'max' )
142
- wandb_log_dict .update ({
143
- "sparsity/overall_above_1e-1" : (overall_act_freq_scores > 1e-1 ).sum ().item (),
144
- "sparsity/overall_above_1e-2" : (overall_act_freq_scores > 1e-2 ).sum ().item (),
145
- "sparsity/overall_below_1e-5" : (overall_act_freq_scores < 1e-5 ).sum ().item (),
146
- "sparsity/overall_below_1e-6" : (overall_act_freq_scores < 1e-6 ).sum ().item (),
147
- })
140
+ if sae .cfg .sae_type == "crosscoder" :
141
+ overall_act_freq_scores = all_reduce_tensor (feature_sparsity , aggregate = "max" )
142
+ wandb_log_dict .update (
143
+ {
144
+ "sparsity/overall_above_1e-1" : (overall_act_freq_scores > 1e-1 ).sum ().item (),
145
+ "sparsity/overall_above_1e-2" : (overall_act_freq_scores > 1e-2 ).sum ().item (),
146
+ "sparsity/overall_below_1e-5" : (overall_act_freq_scores < 1e-5 ).sum ().item (),
147
+ "sparsity/overall_below_1e-6" : (overall_act_freq_scores < 1e-6 ).sum ().item (),
148
+ }
149
+ )
148
150
149
151
self .wandb_logger .log (wandb_log_dict , step = self .cur_step + 1 )
150
152
log_info ["act_freq_scores" ] = torch .zeros_like (log_info ["act_freq_scores" ])
@@ -161,7 +163,11 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
161
163
wandb_log_dict = {
162
164
# losses
163
165
"losses/mse_loss" : l_rec .item (),
164
- ** ({"losses/sparsity_loss" : log_info ["l_s" ].mean ().item ()} if log_info .get ("l_s" , None ) is not None else {}),
166
+ ** (
167
+ {"losses/sparsity_loss" : log_info ["l_s" ].mean ().item ()}
168
+ if log_info .get ("l_s" , None ) is not None
169
+ else {}
170
+ ),
165
171
"losses/overall_loss" : log_info ["loss" ].item (),
166
172
# variance explained
167
173
"metrics/explained_variance" : explained_variance .mean ().item (),
@@ -179,10 +185,16 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
179
185
"details/n_training_tokens" : self .cur_tokens ,
180
186
}
181
187
wandb_log_dict .update (sae .log_statistics ())
182
- if sae .cfg .sae_type == 'crosscoder' :
183
- wandb_log_dict .update ({
184
- "metrics/overall_l0" : all_reduce_tensor (log_info ["feature_acts" ], aggregate = 'max' ).gt (0 ).float ().sum (- 1 ).mean ()
185
- })
188
+ if sae .cfg .sae_type == "crosscoder" :
189
+ wandb_log_dict .update (
190
+ {
191
+ "metrics/overall_l0" : all_reduce_tensor (log_info ["feature_acts" ], aggregate = "max" )
192
+ .gt (0 )
193
+ .float ()
194
+ .sum (- 1 )
195
+ .mean ()
196
+ }
197
+ )
186
198
elif sae .cfg .sae_type == "mixcoder" :
187
199
assert isinstance (sae , MixCoder )
188
200
for modality , (start , end ) in sae .modality_index .items ():
0 commit comments