Logging custom Nemo Metrics using WandB #6974
Unanswered
devansh-shah-11
asked this question in
Q&A
Replies: 0 comments
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
-
I am using Nvidia Nemo to train model and plotting metrics in WandB.
I am using WandB sweep to estimate the best hyperparameters.
As the validation loss wasn't dropping after a point, I thought of an idea to create a test manifest JSON file that comprises 70% of the train data and 30% of the validation data (100 examples total in the test file) to check whether the model is at least overfitting on the train data. How do I log WER and loss on the test data for every epoch?
Am attaching the code below:
import nemo
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.utils import logging, exp_manager
import pytorch_lightning as pl
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models import EncDecCTCModel
import wandb
# MODEL TRAINING
#--- Config Information ---#
def load_config(config_path):
    """Read the YAML config at *config_path* twice.

    Returns a ``(params, config)`` pair: the raw dict parsed with ruamel's
    safe loader, and an OmegaConf view of the same file (used to build the
    NeMo model config).
    """
    try:
        from ruamel.yaml import YAML
    except ModuleNotFoundError:
        from ruamel_yaml import YAML  # legacy distribution name
    safe_yaml = YAML(typ='safe')
    with open(config_path, encoding='utf-8') as fh:
        raw_params = safe_yaml.load(fh)
    print(raw_params)
    from omegaconf import OmegaConf
    omega_cfg = OmegaConf.load(config_path)
    return raw_params, omega_cfg
#--- Creating Checkpoint---#
def create_checkpoint(params,dirpath):
    # NOTE(review): the function body was lost when this snippet was pasted.
    # Given how main() unpacks the call, it is expected to return a
    # (checkpoint_callback, wandb_logger) pair — confirm against the original.
#--- Calculate WER---#
def calculate_wer(model,params,manifest_path, batch_size=4):
    # NOTE(review): this body appears truncated in the paste. As written,
    # `manifest_path` and `batch_size` are never used, the test data is set up
    # from the *validation* config, and nothing is returned — so the callers'
    # val_wer/test_wer end up as None. To evaluate a specific manifest, point
    # the data config at it (e.g. params['validation_ds']['manifest_filepath']
    # = manifest_path) and return the computed WER.
    model.setup_test_data(test_data_config=params['validation_ds'])
    model.cuda()
#--- Transcribing---#
def transcribe(model, audio_path):
    """Transcribe a single audio file and return the result as a string."""
    hypotheses = model.transcribe(paths2audio_files=[audio_path])
    return str(hypotheses)
#--- Saving Model---#
def save_model(model, save_path):
    """Serialize *model* to *save_path* (.nemo) and print the absolute location."""
    import os

    destination = f"{save_path}"
    model.save_to(destination)
    full_path = os.getcwd() + os.path.sep + save_path
    print(f"Model saved at path : {full_path}")
def main():
    """Run one sweep trial: train the CTC model, save it, and evaluate WER.

    Logs both 'val_wer' and 'test_wer' to W&B so the sweep — whose objective
    is to minimize 'val_wer' (see sweep_configuration) — can actually observe
    its target metric.
    """
    import os
    import torch
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping
    from omegaconf import DictConfig
    import wandb
    import tracemalloc

    # Track allocations to debug memory growth across sweep trials.
    tracemalloc.start()
    run = wandb.init()

    config_path = "path_To_config_file"
    params, config = load_config(config_path)

    checkpoint_callback, wandb_logger = create_checkpoint(params, dirpath="Checkpoints")
    trainer = pl.Trainer(gpus=1, max_epochs=5, callbacks=[checkpoint_callback],
                         logger=wandb_logger)
    first_asr_model = nemo_asr.models.EncDecCTCModel(cfg=DictConfig(params), trainer=trainer)
    trainer.fit(first_asr_model)

    # Persist the trained model into the run directory so W&B keeps it.
    save_model(first_asr_model, os.path.join(wandb.run.dir, "model.nemo"))
    #first_asr_model.save(os.path.join(wandb.run.dir, "model.nemo"))

    manifest_path = "manifest_path"
    val_wer = calculate_wer(first_asr_model, params, manifest_path)
    test_path = "path to test manifest.json"
    test_wer = calculate_wer(first_asr_model, params, test_path)

    # BUG FIX: the sweep minimizes 'val_wer', but previously only 'test_wer'
    # was logged, leaving the Bayesian search with no objective to optimize.
    wandb.log({'val_wer': val_wer, 'test_wer': test_wer})
# Bayesian hyperparameter sweep over optimizer settings; the objective
# ('val_wer') must be logged by main() on every run or the search is blind.
sweep_configuration = {
    'method': 'bayes',
    'name': 'try-2',
    'metric': {'goal': 'minimize', 'name': 'val_wer', 'target': 0.2},
    'parameters':
    {
        'train_ds.batch_size': {'values': [16]},
        'optim.weight_decay': {'values': [0.0005, 0.001, 0.002]},
        'optim.lr': {'max': 0.0025, 'min': 0.000025},
        'optim.betas': {'values': [[0.8, 0.7], [0.5, 0.6], [0.7, 0.6], [0.9, 0.999]]},
        'optim.sched.warmup_ratio': {'values': [0, 0.05, 0.1]}
    }
}

sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project='ASR_Try2'
)
# Run 3 trials of main() under this sweep.
wandb.agent(sweep_id, function=main, count=3)

# NOTE(review): this alert fires unconditionally after the agent returns, so a
# "crash" is reported even on success — wrap the agent call in try/except and
# alert from the handler if that is not the intent.
try:
    wandb.alert(title="Runtime Crashed", text="Check the runtime", level=wandb.AlertLevel.ERROR)
    print("alerting that runtime is crashed")
except Exception:  # was a bare `except:` — don't swallow SystemExit/KeyboardInterrupt
    print("\nNo Active runtime found\n")
    print("crashed notification 404")

try:
    wandb.alert(title='Completed', text='Sweep has been completed')
    print("alerting that runtime is completed")
except Exception:  # was a bare `except:`
    print("\nNo Active runtime found\n")
    print("runtime error 404")
Beta Was this translation helpful? Give feedback.
All reactions