103 changes: 57 additions & 46 deletions hubconf.py
@@ -4,66 +4,77 @@
import os
from pathlib import Path
from safetensors import safe_open

import torch
from inference import Mars5TTS, InferenceConfig

-ar_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.pt"
-nar_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.pt"
-
-ar_sf_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.safetensors"
-nar_sf_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.safetensors"

-def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors',
-                  ar_path=None, nar_path=None) -> Mars5TTS:
-    """ Load mars5 english model on `device`, optionally show `progress`. """
-    if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    assert ckpt_format in ['safetensors', 'pt'], "checkpoint format must be 'safetensors' or 'pt'"
-
-    logging.info(f"Using device: {device}")
-    if pretrained == False: raise AssertionError('Only pretrained model currently supported.')
-    logging.info("Loading AR checkpoint...")
-
-    if ar_path is None:
-        if ckpt_format == 'safetensors':
-            ar_ckpt = _load_safetensors_ckpt(ar_sf_url, progress=progress)
-        elif ckpt_format == 'pt':
-            ar_ckpt = torch.hub.load_state_dict_from_url(
-                ar_url, progress=progress, check_hash=False, map_location='cpu'
-            )
-    else: ar_ckpt = torch.load(str(ar_path), map_location='cpu')
-
-    logging.info("Loading NAR checkpoint...")
-    if nar_path is None:
-        if ckpt_format == 'safetensors':
-            nar_ckpt = _load_safetensors_ckpt(nar_sf_url, progress=progress)
-        elif ckpt_format == 'pt':
-            nar_ckpt = torch.hub.load_state_dict_from_url(
-                nar_url, progress=progress, check_hash=False, map_location='cpu'
-            )
-    else: nar_ckpt = torch.load(str(nar_path), map_location='cpu')
-    logging.info("Initializing modules...")
-    mars5 = Mars5TTS(ar_ckpt, nar_ckpt, device=device)
-    return mars5, InferenceConfig

+# Centralized checkpoint URLs for easy management and updates; the key names
+# match the f'ar_{ckpt_format}' / f'nar_{ckpt_format}' lookups used below.
+CHECKPOINT_URLS = {
+    "ar_pt": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.pt",
+    "nar_pt": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.pt",
+    "ar_safetensors": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.safetensors",
+    "nar_safetensors": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.safetensors",
+}

-def _load_safetensors_ckpt(url, progress):
-    """ Loads checkpoint from a safetensors file """
+def load_checkpoint(url, progress=True, ckpt_format='pt'):
+    """ Helper function to download and load a checkpoint, reducing duplication. """
    hub_dir = torch.hub.get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)
    parts = torch.hub.urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        # download it
        torch.hub.download_url_to_file(url, cached_file, None, progress=progress)
    # load checkpoint

+    if ckpt_format == 'safetensors':
+        return _load_safetensors_ckpt(cached_file)
+    else:
+        return torch.load(cached_file, map_location='cpu')
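For reference, a minimal usage sketch for this helper (hypothetical call, assuming the `ar_safetensors` release asset above is reachable): the first call downloads into `torch.hub.get_dir()/checkpoints` (by default `~/.cache/torch/hub/checkpoints`), and later calls reuse the cached file.

```python
# Hedged sketch, not part of the diff: fetch the AR safetensors checkpoint
# and inspect it. Re-running reuses the cached copy instead of re-downloading.
ckpt = load_checkpoint(CHECKPOINT_URLS['ar_safetensors'],
                       progress=True, ckpt_format='safetensors')
print(list(ckpt['model'])[:5])  # first few tensor names
print(ckpt['vocab'].keys())     # dict_keys(['texttok.model', 'speechtok.model'])
```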

+def _load_safetensors_ckpt(file_path):
+    """ Loads a safetensors checkpoint file """
    ckpt = {}
-    with safe_open(cached_file, framework='pt', device='cpu') as f:
+    with safe_open(file_path, framework='pt', device='cpu') as f:
        metadata = f.metadata()
        ckpt['vocab'] = {'texttok.model': metadata['texttok.model'], 'speechtok.model': metadata['speechtok.model']}
        ckpt['model'] = {}
-        for k in f.keys(): ckpt['model'][k] = f.get_tensor(k)
+        ckpt['model'] = {k: f.get_tensor(k) for k in f.keys()}
    return ckpt
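Note that this loader expects the tokenizer vocabularies in the safetensors metadata block under the `texttok.model` and `speechtok.model` keys. A minimal sketch of writing a file the loader can read, using toy placeholder values rather than real MARS5 weights:

```python
import torch
from safetensors.torch import save_file

# Toy tensors and metadata purely for illustration; real checkpoints hold
# the full AR/NAR state dicts plus serialized tokenizer models.
tensors = {'embedding.weight': torch.zeros(8, 4)}
metadata = {'texttok.model': '<serialized text tokenizer>',
            'speechtok.model': '<serialized speech tokenizer>'}
save_file(tensors, 'toy.safetensors', metadata=metadata)

ckpt = _load_safetensors_ckpt('toy.safetensors')
assert 'embedding.weight' in ckpt['model']
```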

+def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors',
+                  ar_path=None, nar_path=None):
+    """ Load the MARS5 English model on `device`, optionally showing `progress`.
+
+    Also handles user-provided checkpoint paths, supporting both .pt and
+    .safetensors formats.
+    """
+    if device is None:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    logging.info(f"Using device: {device}")
+
+    if not pretrained:
+        raise ValueError('Only pretrained models are currently supported.')
+    assert ckpt_format in ('safetensors', 'pt'), "checkpoint format must be 'safetensors' or 'pt'"
+
+    # If a local path is given, infer the checkpoint format from its file
+    # extension; otherwise download the requested format from the release URLs.
+    if ar_path is not None:
+        if ar_path.endswith('.pt'):
+            ar_ckpt = torch.load(str(ar_path), map_location='cpu')
+        elif ar_path.endswith('.safetensors'):
+            ar_ckpt = _load_safetensors_ckpt(ar_path)
+        else:
+            raise ValueError("Unsupported file format for ar_path. Please provide a .pt or .safetensors file.")
+    else:
+        ar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'ar_{ckpt_format}'], progress, ckpt_format)
+
+    if nar_path is not None:
+        if nar_path.endswith('.pt'):
+            nar_ckpt = torch.load(str(nar_path), map_location='cpu')
+        elif nar_path.endswith('.safetensors'):
+            nar_ckpt = _load_safetensors_ckpt(nar_path)
+        else:
+            raise ValueError("Unsupported file format for nar_path. Please provide a .pt or .safetensors file.")
+    else:
+        nar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'nar_{ckpt_format}'], progress, ckpt_format)
+
+    logging.info("Initializing models...")
+    return Mars5TTS(ar_ckpt, nar_ckpt, device=device), InferenceConfig
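Since `mars5_english` is the hubconf entrypoint, the usual way to call it is through `torch.hub.load`, which forwards keyword arguments to the function above. A hedged end-to-end sketch (repo spec taken from the release URLs; `trust_repo=True` only suppresses torch.hub's first-run confirmation prompt):

```python
import torch

# Download checkpoints and build the model; both return values come from
# mars5_english above (the model instance and the InferenceConfig class).
mars5, config_class = torch.hub.load('Camb-ai/MARS5-TTS', 'mars5_english',
                                     trust_repo=True, ckpt_format='safetensors')
cfg = config_class()  # assumes InferenceConfig has usable defaults

# Local checkpoints also work; the format is inferred from the extension:
# mars5, config_class = torch.hub.load('Camb-ai/MARS5-TTS', 'mars5_english',
#     ar_path='ckpts/ar.safetensors', nar_path='ckpts/nar.safetensors')
```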