Skip to content

OOM fix #95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 53 additions & 14 deletions validator-api/validator_api/dataset_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
import random
import tempfile
import os

from datasets import Dataset, Audio
from huggingface_hub import HfApi
Expand Down Expand Up @@ -101,20 +102,32 @@ def submit(self) -> None:
self.current_batch = self.current_batch[self.desired_batch_size:]
self.desired_batch_size = get_random_batch_size(config.UPLOAD_BATCH_SIZE)



class AudioDatasetUploader:
def __init__(self):
self.current_batch = []
self.min_batch_size = 8
self.desired_batch_size = get_random_batch_size(config.UPLOAD_AUDIO_BATCH_SIZE)
self.temp_dir = self.make_temp_dir()

def make_temp_dir(self):
return tempfile.mkdtemp()

def convert_audio_to_wav(self, audio_bytes: str) -> bytes:
temp_audiofile = tempfile.NamedTemporaryFile(suffix=".wav")
def clean_temp_dir(self, batch_to_clean: List[dict]):
if self.temp_dir:
for item in batch_to_clean:
try:
os.unlink(item["audio"]["path"])
except:
pass

def convert_audio_to_wav(self, audio_bytes: str) -> str:
temp_audiofile = tempfile.NamedTemporaryFile(suffix=".wav", dir=self.temp_dir, delete=False)
audio_bytes = base64.b64decode(audio_bytes)
with open(temp_audiofile.name, "wb") as f:
f.write(audio_bytes)
return temp_audiofile.read()

return temp_audiofile.name

def add_audios(
self, metadata: List[AudioMetadata], audio_ids: List[str],
Expand All @@ -126,14 +139,11 @@ def add_audios(

audio_files = [self.convert_audio_to_wav(audio.audio_bytes) for audio in metadata]



self.current_batch.extend([
{
"audio_id": audio_uuid,
"youtube_id": audio.video_id,
# "audio_bytes": audio.audio_bytes,
"audio": {"path": audio_file, "array": sf.read(BytesIO(base64.b64decode(audio.audio_bytes)))[0], "sampling_rate": 16000},
"audio": {"path": audio_file, "array": sf.read(audio_file)[0], "sampling_rate": 16000},
"start_time": audio.start_time,
"end_time": audio.end_time,
"audio_embed": audio.audio_emb,
Expand All @@ -158,10 +168,10 @@ def submit(self) -> None:
if len(self.current_batch) < self.min_batch_size:
print(f"Need at least {self.min_batch_size} audios to submit, but have {len(self.current_batch)}")
return
data = self.current_batch[:self.desired_batch_size]
print(f"Uploading batch of {len(self.current_batch)} audios")
batch_to_upload = self.current_batch[:self.desired_batch_size]
print(f"Uploading batch of {len(batch_to_upload)} audios")
with BytesIO() as f:
dataset = Dataset.from_list(data)
dataset = Dataset.from_list(batch_to_upload)
dataset = dataset.cast_column("audio", Audio())
num_bytes = dataset.to_parquet(f)
try:
Expand All @@ -175,11 +185,40 @@ def submit(self) -> None:
print(f"Uploaded {num_bytes} bytes to Hugging Face")
except Exception as e:
print(f"Error uploading to Hugging Face: {e}")

# Clean up temp files after successful upload
self.clean_temp_dir(batch_to_upload)
self.current_batch = self.current_batch[self.desired_batch_size:]
self.desired_batch_size = get_random_batch_size(config.UPLOAD_AUDIO_BATCH_SIZE)




audio_dataset_uploader = AudioDatasetUploader()
video_dataset_uploader = DatasetUploader()


if __name__ == "__main__":
audio_wav_file = "../example.wav"
with open(audio_wav_file, "rb") as f:
audio_bytes = base64.b64encode(f.read()).decode('utf-8')
audio_dataset_uploader.add_audios(
metadata=[
AudioMetadata(
video_id="123",
start_time=0,
end_time=10,
audio_bytes=audio_bytes,
audio_emb=[],
views=0,
diar_timestamps_start=[],
diar_timestamps_end=[],
diar_speakers=[],
)
] * 10,
audio_ids=list(range(10)),
inverse_der=0.0,
audio_length_score=0.0,
audio_quality_total_score=0.0,
audio_query_score=0.0,
query="",
total_score=0.0,
)
audio_dataset_uploader.submit()