From 36e9a014d22fd74763967915592002f5d9b775d6 Mon Sep 17 00:00:00 2001 From: Salman Date: Fri, 21 Feb 2025 19:59:44 +0000 Subject: [PATCH] OOM fix --- validator-api/validator_api/dataset_upload.py | 67 +++++++++++++++---- 1 file changed, 53 insertions(+), 14 deletions(-) diff --git a/validator-api/validator_api/dataset_upload.py b/validator-api/validator_api/dataset_upload.py index 1880f57..76a0018 100644 --- a/validator-api/validator_api/dataset_upload.py +++ b/validator-api/validator_api/dataset_upload.py @@ -3,6 +3,7 @@ from datetime import datetime import random import tempfile +import os from datasets import Dataset, Audio from huggingface_hub import HfApi @@ -101,20 +102,32 @@ def submit(self) -> None: self.current_batch = self.current_batch[self.desired_batch_size:] self.desired_batch_size = get_random_batch_size(config.UPLOAD_BATCH_SIZE) - class AudioDatasetUploader: def __init__(self): self.current_batch = [] self.min_batch_size = 8 self.desired_batch_size = get_random_batch_size(config.UPLOAD_AUDIO_BATCH_SIZE) + self.temp_dir = self.make_temp_dir() + + def make_temp_dir(self): + return tempfile.mkdtemp() - def convert_audio_to_wav(self, audio_bytes: str) -> bytes: - temp_audiofile = tempfile.NamedTemporaryFile(suffix=".wav") + def clean_temp_dir(self, batch_to_clean: List[dict]): + if self.temp_dir: + for item in batch_to_clean: + try: + os.unlink(item["audio"]["path"]) + except: + pass + + def convert_audio_to_wav(self, audio_bytes: str) -> str: + temp_audiofile = tempfile.NamedTemporaryFile(suffix=".wav", dir=self.temp_dir, delete=False) audio_bytes = base64.b64decode(audio_bytes) with open(temp_audiofile.name, "wb") as f: f.write(audio_bytes) - return temp_audiofile.read() + + return temp_audiofile.name def add_audios( self, metadata: List[AudioMetadata], audio_ids: List[str], @@ -126,14 +139,11 @@ def add_audios( audio_files = [self.convert_audio_to_wav(audio.audio_bytes) for audio in metadata] - - self.current_batch.extend([ { "audio_id": audio_uuid, "youtube_id": audio.video_id, - # "audio_bytes": audio.audio_bytes, - "audio": {"path": audio_file, "array": sf.read(BytesIO(base64.b64decode(audio.audio_bytes)))[0], "sampling_rate": 16000}, + "audio": {"path": audio_file, "array": sf.read(audio_file)[0], "sampling_rate": 16000}, "start_time": audio.start_time, "end_time": audio.end_time, "audio_embed": audio.audio_emb, @@ -158,10 +168,10 @@ def submit(self) -> None: if len(self.current_batch) < self.min_batch_size: print(f"Need at least {self.min_batch_size} audios to submit, but have {len(self.current_batch)}") return - data = self.current_batch[:self.desired_batch_size] - print(f"Uploading batch of {len(self.current_batch)} audios") + batch_to_upload = self.current_batch[:self.desired_batch_size] + print(f"Uploading batch of {len(batch_to_upload)} audios") with BytesIO() as f: - dataset = Dataset.from_list(data) + dataset = Dataset.from_list(batch_to_upload) dataset = dataset.cast_column("audio", Audio()) num_bytes = dataset.to_parquet(f) try: @@ -175,11 +185,40 @@ def submit(self) -> None: print(f"Uploaded {num_bytes} bytes to Hugging Face") except Exception as e: print(f"Error uploading to Hugging Face: {e}") + + # Clean up temp files after successful upload + self.clean_temp_dir(batch_to_upload) self.current_batch = self.current_batch[self.desired_batch_size:] self.desired_batch_size = get_random_batch_size(config.UPLOAD_AUDIO_BATCH_SIZE) - - - audio_dataset_uploader = AudioDatasetUploader() video_dataset_uploader = DatasetUploader() + + +if __name__ == "__main__": + audio_wav_file = "../example.wav" + with open(audio_wav_file, "rb") as f: + audio_bytes = base64.b64encode(f.read()).decode('utf-8') + audio_dataset_uploader.add_audios( + metadata=[ + AudioMetadata( + video_id="123", + start_time=0, + end_time=10, + audio_bytes=audio_bytes, + audio_emb=[], + views=0, + diar_timestamps_start=[], + diar_timestamps_end=[], + diar_speakers=[], + ) + ] * 10, + audio_ids=list(range(10)), + inverse_der=0.0, + audio_length_score=0.0, + audio_quality_total_score=0.0, + audio_query_score=0.0, + query="", + total_score=0.0, + ) + audio_dataset_uploader.submit()