Alternatively, if I stop trying to move everything to CPU, i.e. if I run the following code:

```python
import torch
from torch.utils.data import Dataset
from musiclm_pytorch import MuLaN, MuLaNTrainer, AudioSpectrogramTransformer, TextTransformer

audio_transformer = AudioSpectrogramTransformer(
    dim=512,
    depth=6,
    heads=8,
    dim_head=64,
    spec_n_fft=128,
    spec_win_length=24,
    spec_aug_stretch_factor=0.8,
)

text_transformer = TextTransformer(
    dim=512,
    depth=6,
    heads=8,
    dim_head=64,
    max_seq_len=512,
)

mulan = MuLaN(
    audio_transformer=audio_transformer,
    text_transformer=text_transformer,
)
mulan.eval()

# Toy data: 5 random waveforms of 1024 samples, 5 random token sequences of length 512
wavs = torch.randn(5, 1024)
texts = torch.randint(0, 20000, (5, 512))
print(wavs.shape, texts.shape)


class TextAudioDataset(Dataset):
    def __init__(self, wavs, texts):
        super().__init__()
        # __len__ must return a non-negative int, so check the lengths here
        # instead of returning -1 from __len__ (len() would raise ValueError)
        if len(wavs) != len(texts):
            raise ValueError("wavs and texts must have the same number of items")
        self.wavs = wavs
        self.texts = texts

    def __len__(self):
        return len(self.wavs)

    def __getitem__(self, idx):
        return self.wavs[idx], self.texts[idx]


trainer = MuLaNTrainer(
    mulan=mulan,
    dataset=TextAudioDataset(wavs, texts),
    batch_size=2,
)
trainer.train()
```

I get the following error:
I have tried setting these environment variables:

```python
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
```

But the
-
I am trying to train this on CPU (on a small dataset) to validate some ideas.
I am getting the following error:
Is there a way to run the entire thing on CPU?
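In plain PyTorch, running everything on the CPU means the model's parameters and every batch must live on the same explicit device. A minimal sketch of that pattern follows. It is a generic training loop, not `MuLaNTrainer`: `PairDataset`, `ToyContrastive` and `train_on_cpu` are hypothetical stand-ins with a simple contrastive loss. `MuLaNTrainer` chooses its device through Hugging Face Accelerate, whose `Accelerator` accepts `cpu=True`. Whether the installed trainer lets that option be passed through (for example via an `accelerate_kwargs` argument) is an assumption to check against its signature.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset

DEVICE = torch.device("cpu")  # pin everything to the CPU, even if CUDA or MPS is present


class PairDataset(Dataset):
    """Hypothetical toy dataset of (wav, text-token) pairs, shaped like the post's data."""

    def __init__(self, wavs, texts):
        if len(wavs) != len(texts):
            raise ValueError("wavs and texts must have the same length")
        self.wavs, self.texts = wavs, texts

    def __len__(self):
        return len(self.wavs)

    def __getitem__(self, idx):
        return self.wavs[idx], self.texts[idx]


class ToyContrastive(nn.Module):
    """Hypothetical stand-in for MuLaN: embeds audio and text into a shared space."""

    def __init__(self, wav_len=1024, vocab=20000, dim=32):
        super().__init__()
        self.audio = nn.Linear(wav_len, dim)
        self.text = nn.EmbeddingBag(vocab, dim)  # mean-pools the token embeddings

    def forward(self, wavs, texts):
        a = F.normalize(self.audio(wavs), dim=-1)
        t = F.normalize(self.text(texts), dim=-1)
        logits = a @ t.t()  # similarity of every audio clip to every text
        labels = torch.arange(len(wavs), device=wavs.device)  # matching pairs on the diagonal
        return F.cross_entropy(logits, labels)


def train_on_cpu(steps=3):
    torch.manual_seed(0)
    ds = PairDataset(torch.randn(5, 1024), torch.randint(0, 20000, (5, 16)))
    loader = DataLoader(ds, batch_size=2, shuffle=True)
    model = ToyContrastive().to(DEVICE)  # move the parameters once
    opt = torch.optim.Adam(model.parameters(), lr=3e-4)
    losses = []
    for _ in range(steps):
        for wavs, texts in loader:
            wavs, texts = wavs.to(DEVICE), texts.to(DEVICE)  # move every batch as well
            loss = model(wavs, texts)
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
    return model, losses
```

Calling `train_on_cpu(steps=2)` runs two passes over the five toy pairs (three batches each) without touching a GPU.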