Description
Hi,
I am trying to fine-tune one of the pretrained HTS-AT models for binary classification on a custom dataset. I have already managed to do exactly the same thing with a pretrained BEATs model, but somehow I can't make it work with HTS-AT.
Here is a summary of what I do:
- I create the `HTSAT_Swin_Transformer` model with the same config as in your ESC-50 fine-tuning example, the only differences being `num_classes=1` and `loss_type = "clip_bce"`, since I do binary classification.
- I load one of the pretrained checkpoints (e.g. `HTSAT_AudioSet_Saved_1.ckpt`) and update all the model weights except `sed_model.tscam_conv.weight` and `sed_model.tscam_conv.bias` (I have verified that all weights are correctly loaded).
- I freeze all parameters except the `tscam_conv` ones (4.6K trainable params left).
- I feed the model batches of raw audio (sampled at 32000Hz and zero-padded to the length of the longest clip in the batch) and compute the loss against its 0-1 targets with `nn.BCELoss`.
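For readers trying to reproduce this, the four steps above can be sketched in plain PyTorch. This is a minimal, hypothetical sketch, not HTS-AT's own loading code: `ToyModel`, `load_all_but`, `freeze_all_but` and `pad_batch` are illustrative names, `ToyModel` merely stands in for `HTSAT_Swin_Transformer` (a backbone plus a `tscam_conv` head), and the checkpoint keys are assumed to already match the model's own parameter names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyModel(nn.Module):
    """Hypothetical stand-in for HTSAT_Swin_Transformer: a backbone plus a tscam_conv head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.tscam_conv = nn.Conv1d(4, 1, kernel_size=3, padding=1)


def load_all_but(model: nn.Module, state_dict: dict, skip: str) -> list:
    """Load every pretrained tensor whose key does not contain `skip`.

    Returns the model keys that keep their fresh initialization."""
    filtered = {k: v for k, v in state_dict.items() if skip not in k}
    missing, unexpected = model.load_state_dict(filtered, strict=False)
    assert not unexpected, f"keys not found in model: {unexpected}"
    return missing


def freeze_all_but(model: nn.Module, keep: str) -> int:
    """Train only parameters whose name contains `keep`; return the trainable count."""
    for name, param in model.named_parameters():
        param.requires_grad = keep in name
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def pad_batch(clips: list) -> torch.Tensor:
    """Right-pad 1-D waveforms with zeros to the length of the longest clip."""
    longest = max(clip.shape[-1] for clip in clips)
    return torch.stack([F.pad(clip, (0, longest - clip.shape[-1])) for clip in clips])


pretrained = ToyModel().state_dict()  # stand-in for the checkpoint's weights
model = ToyModel()
print("re-initialized:", load_all_but(model, pretrained, skip="tscam_conv"))
print("trainable params:", freeze_all_but(model, keep="tscam_conv"))

batch = pad_batch([torch.randn(32000), torch.randn(48000)])  # 1 s and 1.5 s at 32000Hz
print("batch shape:", tuple(batch.shape))

# nn.BCELoss expects probabilities in [0, 1] and float targets shaped like the input
probs = torch.sigmoid(torch.randn(2, 1))  # stand-in for a (batch, num_classes=1) output
targets = torch.tensor([0, 1]).float().view_as(probs)
print("loss:", nn.BCELoss()(probs, targets).item())
```

Printing the re-initialized keys and the trainable-parameter count is a cheap way to confirm that only the head was reset and left trainable.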
I follow exactly the same process with BEATs, the only differences being the layer names and the input sample rate (16000Hz). Yet I can't get the HTS-AT model to learn anything. For example, here is the `val_loss` over a few epochs for several attempts (blue is the BEATs fine-tuning, for reference):
[Figure: `val_loss` per epoch for several HTS-AT runs, with the BEATs run in blue]
I have tried different learning rates, pretrained weights and optimizers, but none of these changes seems to have any effect.
Since my dataset is composed of roughly 10% positives, a dummy model always outputting a constant value of 0.10 would have a `val_loss` of approximately 0.27, which is what all my attempts seem to converge toward. Basically, the model is not learning anything from the input.
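The constant-prediction baseline above can be checked numerically. Assuming `val_loss` is the mean `nn.BCELoss` (natural log) over the validation set, a model that always outputs `q` on data with positive rate `r` has loss `-(r*ln q + (1-r)*ln(1-q))`, which is minimized at `q = r` and then equals the binary entropy of `r`. The function names below are illustrative only. At exactly 10% positives this gives about 0.325, so a plateau near 0.27 corresponds to a positive rate of roughly 7 to 8%; either way, it is the "predict the prior" level.

```python
import math


def constant_bce(pos_rate: float, q: float) -> float:
    """Mean BCE (natural log) of a model that always predicts probability q."""
    return -(pos_rate * math.log(q) + (1 - pos_rate) * math.log(1 - q))


def constant_baseline_bce(pos_rate: float) -> float:
    """Best constant predictor (q = pos_rate): the binary entropy of the positive rate."""
    return constant_bce(pos_rate, pos_rate)


for rate in (0.05, 0.076, 0.10, 0.15):
    print(f"positive rate {rate:.3f} -> baseline BCE {constant_baseline_bce(rate):.3f}")
```

If the validation loss sits at this value for the dataset's actual positive rate, the classifier head is outputting (close to) the class prior regardless of the input.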
The input data looks "normal". For example, here is what the sound of an ambulance looks like after HTSAT preprocessing:
```python
import librosa
import librosa.display
import matplotlib.pyplot as plt
import torch


# Excerpt from HTSAT_Swin_Transformer.forward, with debug plotting added
def forward(self, x: torch.Tensor, mixup_lambda=None, infer_mode=False):
    x = self.spectrogram_extractor(x)  # (batch_size, 1, time_steps, freq_bins)
    fig, axs = plt.subplots(2)
    # The STFT output is a power spectrogram, so convert it to dB before display
    spec_db = librosa.power_to_db(x[0][0].detach().cpu().numpy().T)
    img = librosa.display.specshow(spec_db, x_axis="time", y_axis="log", ax=axs[0])
    fig.colorbar(img, ax=axs[0], format="%+2.0f dB")
    axs[0].set(title="spectrogram")
    x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins), already in dB
    img = librosa.display.specshow(
        x[0][0].detach().cpu().numpy().T, x_axis="time", y_axis="mel", ax=axs[1]
    )
    fig.colorbar(img, ax=axs[1], format="%+2.0f dB")
    axs[1].set(title="logmel")
    plt.show()
    # ...
```
Am I missing a key detail?