Fine tuning issues #65

@VRichardJP


Hi,

I am trying to fine-tune one of the pretrained HTS-AT models for binary classification on a custom dataset. I have already done exactly the same thing with a pretrained BEATs model, but I can't get it to work with HTS-AT.

Here is a summary of what I do:

  • I create the HTSAT_Swin_Transformer model with the same config as in your ESC-50 fine-tuning example. The only differences are num_classes=1 and loss_type = "clip_bce", since I am doing binary classification
  • I load one of the pretrained checkpoints (e.g. HTSAT_AudioSet_Saved_1.ckpt) and copy all of its weights into the model except sed_model.tscam_conv.weight and sed_model.tscam_conv.bias (I have verified that all weights are loaded correctly)
  • I freeze all parameters except the tscam_conv ones (4.6K trainable parameters left)
  • I feed the model batches of raw audio waveforms (sampled at 32000 Hz and zero-padded to the length of the longest clip in the batch) and compute the loss against the 0/1 targets with nn.BCELoss
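The steps in the list above can be sketched in plain PyTorch. This is a minimal, self-contained sketch, not HTS-AT code: `ToySED` and `ToyWrapper` are hypothetical stand-ins that only mimic the `sed_model.tscam_conv` parameter naming, and the 527-class "pretrained" model (AudioSet's class count) is randomly initialised. It shows loading every checkpoint tensor except the head with `strict=False`, freezing everything but `tscam_conv`, and zero-padding a batch to its longest clip before `nn.BCELoss`. The toy head applies a sigmoid because `nn.BCELoss` expects probabilities rather than logits.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)


class ToySED(nn.Module):
    """Hypothetical stand-in for HTS-AT's sed_model: a backbone plus a tscam_conv head."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = nn.Conv1d(1, 8, kernel_size=3, padding=1)
        self.tscam_conv = nn.Conv1d(8, num_classes, kernel_size=3, padding=1)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # wav: (batch, samples) -> clip-level probabilities (batch, num_classes)
        h = torch.relu(self.backbone(wav.unsqueeze(1)))
        logits = self.tscam_conv(h).mean(dim=-1)  # average over time, padded region included
        return torch.sigmoid(logits)  # probabilities, as nn.BCELoss expects


class ToyWrapper(nn.Module):
    """Gives parameters the 'sed_model.' prefix used in the checkpoint keys."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.sed_model = ToySED(num_classes)


def load_all_but_head(model: nn.Module, state: dict, head_prefix: str = "sed_model.tscam_conv."):
    # Drop the head tensors (their shapes differ: 527 classes vs 1) and load the rest.
    filtered = {k: v for k, v in state.items() if not k.startswith(head_prefix)}
    result = model.load_state_dict(filtered, strict=False)
    return result.missing_keys, result.unexpected_keys


def freeze_all_but_head(model: nn.Module, head_name: str = "tscam_conv") -> int:
    # Only parameters whose name contains head_name stay trainable.
    for name, p in model.named_parameters():
        p.requires_grad = head_name in name
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def collate(batch):
    # batch: list of (1-D waveform tensor, 0/1 label); zero-pad to the longest clip.
    wavs, labels = zip(*batch)
    x = pad_sequence(list(wavs), batch_first=True, padding_value=0.0)
    y = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)  # (batch, 1)
    return x, y


pretrained = ToyWrapper(num_classes=527)  # plays the role of the AudioSet checkpoint
model = ToyWrapper(num_classes=1)         # binary task
missing, unexpected = load_all_but_head(model, pretrained.state_dict())
trainable = freeze_all_but_head(model)

optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)
criterion = nn.BCELoss()

# One training step on a toy batch of two clips with different lengths.
x, y = collate([(torch.randn(3200), 0), (torch.randn(1600), 1)])
optimizer.zero_grad()
loss = criterion(model.sed_model(x), y)
loss.backward()
optimizer.step()
print(f"missing={missing} trainable={trainable} loss={loss.item():.3f}")
```

Checking `missing_keys` is a cheap way to confirm that only the head was left uninitialised; anything else showing up there means a prefix mismatch between the checkpoint and the model.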

I follow exactly the same process with BEATs, the only differences being the layer names and the input sample rate (16000 Hz). Yet I can't get the HTS-AT model to learn anything. For example, here is the val_loss over the first few epochs of several attempts (blue is the BEATs fine-tuning run, for reference):

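[Image: val_loss per epoch for several HTS-AT fine-tuning attempts, with the BEATs run in blue for reference]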

I have tried different learning rates, pretrained weights and optimizers, but none of these changes seems to have any effect.

Since my dataset is roughly 10% positives, a dummy model that always outputs 0.10 would have a val_loss of approximately 0.27, which is what all my attempts seem to converge toward. In other words, the model is not learning anything from the input.
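As a sanity check on that baseline: the constant prediction that minimises BCE is the positive rate p itself, and its mean loss is the binary entropy H(p) = -p ln p - (1 - p) ln(1 - p), using the natural log as `nn.BCELoss` does. A minimal sketch:

```python
import math


def constant_predictor_bce(p: float) -> float:
    """Mean BCE of a model that always outputs p on data whose positive rate is p.

    This equals the binary entropy H(p), in nats.
    """
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))


for p in (0.075, 0.08, 0.10):
    print(f"positive rate {p:.3f} -> baseline BCE {constant_predictor_bce(p):.3f}")
```

With exactly 10% positives the plateau would be about 0.325; a plateau near 0.27 corresponds to a positive rate of roughly 7.5 to 8%, which fits "roughly 10%". Either way, a val_loss stuck at H(p) is what you would expect from a model whose output carries no information beyond the class prior.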

The input data looks "normal". For example, here is what the sound of an ambulance looks like after HTSAT preprocessing:

[Image: spectrogram (top) and logmel (bottom) of an ambulance sound after HTS-AT preprocessing]

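    # Debugging snippet inside HTSAT_Swin_Transformer.forward; needs, at module level:
    #   import librosa.display
    #   import matplotlib.pyplot as plt
    def forward(
        self, x: torch.Tensor, mixup_lambda=None, infer_mode=False
    ):  # out_feat_keys: List[str] = None):
        x = self.spectrogram_extractor(x)  # (batch_size, 1, time_steps, freq_bins), linear power

        # specshow defaults to sr=22050 / hop_length=512; pass the real values so the
        # time axis is labelled correctly (32000 Hz input; hop_length=320 is assumed from the HTS-AT config)
        sr, hop_length = 32000, 320

        fig, axs = plt.subplots(2)
        img = librosa.display.specshow(
            x[0][0].detach().cpu().numpy().T,
            sr=sr, hop_length=hop_length, x_axis="time", y_axis="log", ax=axs[0],
        )
        fig.colorbar(img, ax=axs[0])  # linear power, not yet converted to dB
        axs[0].set(title="spectrogram")

        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins), in dB

        img = librosa.display.specshow(
            x[0][0].detach().cpu().numpy().T,
            sr=sr, hop_length=hop_length, x_axis="time", y_axis="mel", ax=axs[1],
        )
        fig.colorbar(img, ax=axs[1], format="%+2.0f dB")
        axs[1].set(title="logmel")
        plt.show()

        # ...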

Am I missing a key detail?
