diff --git a/configs/config.yaml b/configs/config.yaml index 18b6ce5..eaef349 100755 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -14,9 +14,9 @@ model: # paths - llama_path: "path to llama model" - whisper_path: "path to whisper model" - beats_path: "path to beats model" + llama_path: "path to llama model" # huggingface repo_id or local path + whisper_path: "path to whisper model" # huggingface repo_id or local path + beats_path: "path to beats model" # .pt file path downloaded from main git repo ckpt: "model ckpt" # if not "", load model from ckpt for training or evaluation @@ -27,7 +27,11 @@ model: use_speech_Qformer: True freeze_speech_QFormer: False window_level_Qformer: True - num_speech_query_token: 1 + num_speech_query_token: 1 # Q-Former hyper-parameter + # Given audio length L seconds, Number of windows = floor((L - second_per_window) / second_stride) + 1 + # With current values (1/3 second), each second produces 3 non-overlapping windows + # Numeric example: + # - For a 10-second audio: Number of windows = floor((10 - 0.333333) / 0.333333) + 1 = 29 + 1 = 30 windows second_per_window: 0.333333 second_stride: 0.333333 @@ -42,9 +46,9 @@ model: multi_prompt: True prompt_template: "USER: {}\nASSISTANT:" - prompt_path: "prompts/train_prompt.json" - test_prompt_path: "prompts/test_prompt.json" - max_txt_len: 300 + prompt_path: "prompts/train_prompt.json" # local path to train prompt + test_prompt_path: "prompts/test_prompt.json" # local path to test prompt + max_txt_len: 300 # text token limit, beyond which the text will be truncated end_sym: "" datasets: @@ -56,23 +60,24 @@ datasets: run: # log & settings - seed: 42 - output_dir: "output directory" + seed: 42 # random seed + output_dir: "output directory" # directory to save checkpoints & tensorboard logs, etc. 
evaluate: False # if True, only evaluate model on test data log_freq: 5 - epoch_based: False - iters_per_epoch: 3000 - accum_grad_iters: 1 - batch_size_train: 8 - batch_size_eval: 8 - num_workers: 8 + save_freq: 1000 # save checkpoint every 1000 iterations + epoch_based: False # if True, exhaust all the data in each epoch + iters_per_epoch: 3000 # number of iterations per epoch, set manually when epoch_based is False + accum_grad_iters: 1 # GBS / (MBS * DP) + batch_size_train: 8 # MBS + batch_size_eval: 8 # MBS + num_workers: 8 # number of dataloader workers device: "cuda" use_distributed: True - amp: True - world_size: 1 - dist_url: "env://" + amp: True # automatic mixed precision by torch.cuda.amp + world_size: 1 # overwritten by actual world_size (n_nodes * n_gpus) + dist_url: "env://" # use env vars (MASTER_ADDR, MASTER_PORT) for distributed training # optimizer & scheduler optims: diff --git a/dist_utils.py b/dist_utils.py index c293fa2..58aaacb 100755 --- a/dist_utils.py +++ b/dist_utils.py @@ -55,10 +55,12 @@ def is_main_process(): def init_distributed_mode(args): + # if using torchrun if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ["WORLD_SIZE"]) args.gpu = int(os.environ["LOCAL_RANK"]) + # if using srun in SLURM elif "SLURM_PROCID" in os.environ: args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() diff --git a/models/README.md b/models/README.md new file mode 100644 index 0000000..8b1dac2 --- /dev/null +++ b/models/README.md @@ -0,0 +1,226 @@ +# SALMONN Models + +This directory contains the implementation of the various models used in the SALMONN (Speech Audio Language Music Open Neural Network) project, including Whisper, BEATs, Qformer, and LLaMA. + +## 1. Whisper Architecture + +Whisper is a speech processing model used in SALMONN for encoding speech inputs. It's a Transformer-based encoder-decoder model originally designed for automatic speech recognition (ASR) and translation, but in SALMONN primarily the encoder is used to extract rich representations from speech. + +### 1.1 Class Diagram + +#### Class Inheritance Hierarchy +``` +PreTrainedModel +│ +└──> WhisperPreTrainedModel + │ + └──> WhisperModel + │ + └──> WhisperForConditionalGeneration +``` + +#### Component Composition Structure +``` +WhisperModel +│ +└──> WhisperEncoder + │ + ├──> WhisperEncoderLayer (multiple layers) + │ │ + │ ├──> WhisperAttention + │ │ │ + │ │ ├──> Linear projections (q_proj, k_proj, v_proj) + │ │ └──> Linear projection (out_proj) + │ │ + │ ├──> LayerNorm (self_attn_layer_norm) + │ │ + │ ├──> Feed-Forward Network + │ │ │ + │ │ ├──> Linear (fc1) + │ │ ├──> Activation function (GELU) + │ │ └──> Linear (fc2) + │ │ + │ └──> LayerNorm (final_layer_norm) + │ + └──> LayerNorm (layer_norm) +``` + +## 2. BEATs Architecture + +BEATs (Bidirectional Encoder representation from Audio Transformers) is an audio processing model used in SALMONN for encoding general audio inputs. It's a transformer-based model designed to extract rich representations from audio spectrograms, complementing Whisper's speech-focused features.
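To make this pairing concrete, here is a small self-contained sketch of how the two encoders' outputs are aligned and fused downstream (illustrative shapes only; the real logic lives in `_encode_auditory_feature` in `models/salmonn.py`, and the actual snippet appears in Section 2.2 below): the shorter sequence is padded along the time axis, then the two are concatenated along the channel axis.

```python
# Toy sketch of Whisper + BEATs feature fusion (illustrative shapes, not SALMONN's API).
import torch
import torch.nn.functional as F

speech_embeds = torch.randn(1, 1500, 1280)  # Whisper-large encoder output: (B, T, d_model=1280)
audio_embeds = torch.randn(1, 748, 768)     # BEATs output: (B, T_audio, encoder_embed_dim=768)

# pad the shorter sequence along the time axis so both share the same T
if audio_embeds.size(1) < speech_embeds.size(1):
    audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
elif audio_embeds.size(1) > speech_embeds.size(1):
    speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))

# concatenate along the channel axis: (B, T, 1280 + 768) = (B, T, 2048)
fused = torch.cat((speech_embeds, audio_embeds), dim=-1)
print(fused.shape)  # torch.Size([1, 1500, 2048])
```

The 2048-channel result is why the speech Q-Former is initialized with `speech_width = d_model + encoder_embed_dim` when a BEATs checkpoint is provided (see `salmonn.py`).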
+ +### 2.1 Class Diagram + +#### Class Inheritance Hierarchy +``` +nn.Module +│ +└──> BEATs +``` + +#### Component Composition Structure +``` +BEATs +│ +├──> nn.Conv2d (patch_embedding) +│ +├──> nn.Linear (post_extract_proj) +│ +├──> LayerNorm (layer_norm) +│ +└──> TransformerEncoder + │ + └──> ModuleList of TransformerEncoderLayer + │ + ├──> MultiheadAttention + │ │ + │ └──> Linear projections (q_proj, k_proj, v_proj, out_proj) + │ + ├──> LayerNorm (self_attn_layer_norm) + │ + ├──> FeedForward + │ │ + │ ├──> Linear (fc1) + │ ├──> Activation function (GELU) + │ ├──> Dropout + │ └──> Linear (fc2) + │ + └──> LayerNorm (final_layer_norm) +``` +### 2.2 Use in SALMONN + +In the SALMONN architecture, BEATs serves as a secondary audio encoder that: +1. Processes raw audio waveforms into audio features that complement Whisper's speech features +2. Provides additional context for non-speech audio elements like music, environmental sounds, and acoustic events + +These features are then combined with Whisper features before being processed by the Qformer, enhancing SALMONN's ability to understand the full audio context. + +In SALMONN, BEATs encodes the raw waveform (`raw_wav`) in parallel with Whisper's spectrogram encoder, and the resulting `audio_embeds` are handed to the modality adapter together with the Whisper features (see `encode_speech` in `salmonn.py`): +```python +return self._encode_auditory_feature(speech_embeds, audio_embeds=audio_embeds) +``` + +The BEATs features are then padded to the same sequence length as the Whisper features and concatenated along the channel dimension (see `_encode_auditory_feature` in `salmonn.py`): +```python +if audio_embeds is not None: + audio_embeds = self.ln_audio(audio_embeds) + speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1) +``` + +This multi-encoder approach allows SALMONN to leverage both speech-specific features from Whisper and general audio features from BEATs, creating a more robust representation of the audio input. + + +## 3. Qformer Architecture + +The Qformer (Query Transformer) is a key component in SALMONN that serves as a modality adapter between speech/audio inputs and the language model. It's based on a modified BERT architecture and is designed to efficiently extract and transform features from one modality to be compatible with another. + +Qformer uses learnable query tokens that interact with input features through cross-attention mechanisms. These query tokens act as information bottlenecks that extract the most relevant information from the input modality. The number of query tokens is typically much smaller than the input sequence length, allowing for efficient information extraction and dimensionality reduction.
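To make the query-token idea concrete, here is a rough, self-contained toy (plain PyTorch attention with random weights, not the BERT-based `BertLMHeadModel` Qformer SALMONN actually uses): a single learnable query token cross-attends to one 17-frame window of fused encoder features and compresses it into a single 768-dimensional vector per window.

```python
# Toy query-token cross-attention (illustration only; SALMONN uses a BERT-based Qformer).
import torch
import torch.nn as nn

hidden = 768      # Q-Former hidden size
num_query = 1     # num_speech_query_token in config.yaml
window_len = 17   # ~0.333 s of Whisper features (1500 frames per 30 s clip)

query_tokens = nn.Parameter(torch.zeros(1, num_query, hidden))             # learnable queries
input_proj = nn.Linear(2048, hidden)                                       # fused Whisper+BEATs channels -> hidden
cross_attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=12, batch_first=True)

windows = torch.randn(88, window_len, 2048)              # (B*L, K, C): 88 windows from one 30 s clip
keys_values = input_proj(windows)                        # (B*L, K, hidden)
queries = query_tokens.expand(windows.size(0), -1, -1)   # (B*L, Q, hidden)
out, _ = cross_attn(queries, keys_values, keys_values)   # (B*L, Q, hidden): one vector per window
print(out.shape)  # torch.Size([88, 1, 768])
```

In the window-level configuration used by SALMONN, the batch dimension seen by the Qformer is `B*L` (one row per ~0.33 s window), and the per-window outputs are later reshaped back to `(B, L*Q, hidden)` before being projected into the LLaMA embedding space.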
+ +### 3.1 Class Diagram + +#### Class Inheritance Hierarchy +``` +PreTrainedModel +│ +└──> BertPreTrainedModel + ├──> BertModel + ├──> BertLMHeadModel (Qformer) + └──> BertForMaskedLM +``` + +#### Component Composition Structure +``` +Composition Structure (has-a relationships): + +BertLMHeadModel +│ +├──> BertModel +│ │ +│ ├──> BertEmbeddings +│ │ +│ ├──> BertEncoder +│ │ │ +│ │ └──> a stack of BertLayer, each is a transformer block containing +│ │ │ +│ │ ├──> BertAttention +│ │ │ │ +│ │ │ └──> BertSelfAttention (can be made into both self/cross-attention) +│ │ │ +│ │ ├──> BertIntermediate +│ │ │ +│ │ └──> BertOutput +│ │ +│ └──> BertPooler +│ +└──> ClassificationHead (cls) +``` + +### 3.2 Use in SALMONN + +The Qformer architecture consists of: + +- **Query Tokens**: Learnable parameters that serve as the interface between modalities +- **Self-Attention Layers**: Standard BERT-style self-attention for processing query tokens +- **Cross-Attention Layers**: Allow query tokens to attend to input features from another modality + +The rest of BERT architecture is not used in Qformer. +- **Word/Position Embeddings**: Removed +- **Layer Outputs/Intermediates (FFN)**: Removed +- **BertPooler**: Not used +- **CLS head**: Removed + +The Qformer can operate at different levels: +- **Window-level**: Processing fixed-length windows of speech/audio features +- **Sequence-level**: Processing the entire sequence at once + +## 4. LLaMA Architecture + +LLaMA (Large Language Model Meta AI) is the foundation language model used in SALMONN for text generation. It's a decoder-only transformer architecture optimized for efficient inference and strong language understanding capabilities. + +### 4.1 Class Diagram + +#### Class Inheritance Hierarchy +``` +PreTrainedModel +│ +└──> LlamaPreTrainedModel + │ + └──> LlamaForCausalLM +``` + +#### Component Composition Structure +``` +LlamaForCausalLM +│ +├──> LlamaModel +│ │ +│ ├──> nn.Embedding (embed_tokens) +│ │ +│ ├──> nn.ModuleList of LlamaDecoderLayer +│ │ │ +│ │ ├──> LlamaAttention +│ │ │ │ +│ │ │ ├──> Linear projections (q_proj, k_proj, v_proj, o_proj) +│ │ │ └──> LlamaRotaryEmbedding (rotary_emb) +│ │ │ +│ │ ├──> LlamaRMSNorm (input_layernorm) +│ │ │ +│ │ ├──> LlamaMLP +│ │ │ │ +│ │ │ ├──> Linear (gate_proj) +│ │ │ ├──> Linear (up_proj) +│ │ │ ├──> ACT2FN[hidden_act] activation +│ │ │ └──> Linear (down_proj) +│ │ │ +│ │ └──> LlamaRMSNorm (post_attention_layernorm) +│ │ +│ └──> LlamaRMSNorm (norm) +│ +└──> nn.Linear (lm_head) +``` + +In SALMONN, LLaMA is integrated with speech features through a projection layer that maps Qformer outputs to the LLaMA embedding space. The model processes inputs in this sequence: + +1. BOS token embeddings +2. Projected speech embeddings from Qformer +3. 
Text token embeddings (for training) diff --git a/models/salmonn.py b/models/salmonn.py index e11062e..8c9faad 100755 --- a/models/salmonn.py +++ b/models/salmonn.py @@ -16,11 +16,12 @@ import json import contextlib import random +from typing import Optional, Union, List import torch import torch.nn as nn import torch.nn.functional as F -from transformers import LlamaTokenizer, StoppingCriteriaList +from transformers import LlamaTokenizer, StoppingCriteriaList, AutoTokenizer from peft import LoraConfig, TaskType, get_peft_model from .Qformer import BertConfig, BertLMHeadModel @@ -129,6 +130,7 @@ def __init__( param.requires_grad = False logging.info('Loading LLaMA Done') + # randomly initialize LoRA parameters if self.lora: self.peft_config = LoraConfig( task_type=TaskType.CAUSAL_LM, @@ -141,6 +143,7 @@ def __init__( self.llama_model.print_trainable_parameters() logging.info('LoRA Training') + # loading Whisper model from huggingface (remote/local) assert whisper_path logging.info('Loading Whisper Model') self.speech_encoder = WhisperModel.from_pretrained(whisper_path).encoder @@ -151,6 +154,8 @@ def __init__( self.speech_encoder.eval() logging.info("freeze Whisper") + # loading BEATs from local .pt into CPU RAM + # BEATs model is optional, only used for audio feature extraction if self.beats_path: logging.info("Loading BEATs Model") beats_ckpt = torch.load(self.beats_path, map_location='cpu') @@ -164,15 +169,26 @@ def __init__( self.beats.eval() logging.info("freeze BEATs") + # initialize speech QFormer if self.use_speech_Qformer: + # initialize the QFormer + # if BEATs is used, speech and audio features are concatenated along hidden_size dimension if self.beats_path: self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer( num_query_token=num_speech_query_token, speech_width=self.speech_encoder.config.d_model + self.beats.cfg.encoder_embed_dim ) + # if BEATs is not used else: self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer( num_query_token=num_speech_query_token, speech_width=self.speech_encoder.config.d_model ) + # Modify the QFormer by selectively removing components so it can be used as a feature extractor + # in the modality adapter + # Removed: word/position embeddings, layer outputs/intermediates (FFN), CLS head + # Retained: cross-attention layers, Q tokens, self-attention + with open("Qformer.log", "w") as file: + for name, module in self.speech_Qformer.bert.named_modules(): + file.write(f"{name}: {module}\n") self.speech_Qformer.bert.embeddings.word_embeddings = None self.speech_Qformer.bert.embeddings.position_embeddings = None for layer in self.speech_Qformer.bert.encoder.layer: @@ -186,6 +202,8 @@ def __init__( self.speech_query_tokens.requires_grad = False logging.info("freeze Speech QFormer") + # add linear head for modality adapter, from n_channels of QFormer Feature Extractor + # to hidden_size of LLM logging.info('Loading speech LLAMA proj') self.speech_llama_proj = nn.Linear( self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size @@ -203,7 +221,8 @@ def __init__( # feel free to add other aligners here raise NotImplementedError - # prepare prompts + # load prompt templates from json file in prompt_path + # for each task defined in the json file, load into self.prompt_dict as {task: prompt_template} self.prompt_dict = {} if prompt_path: try: @@ -221,40 +240,76 @@ def _encode_auditory_feature(self, speech_embeds, audio_embeds=None): if self.use_speech_Qformer: speech_embeds = self.ln_speech(speech_embeds) if
audio_embeds is not None: + # pad audio and speech embeddings to equal seq_len T -> (B, T, C1), (B, T, C2) + # and concat them along the channel dimension to get speech_embeds + # X: (B, T, C1+C2) audio_embeds = self.ln_audio(audio_embeds) if audio_embeds.size(1) < speech_embeds.size(1): audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1))) elif audio_embeds.size(1) > speech_embeds.size(1): speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1))) speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1) + # attention mask (B, T) speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long).to(speech_embeds.device) + print("0 - Shape of speech_embs (B, T, C): ", speech_embeds.shape) # [1, 1500, 2048] + # Default Case: window if self.window_level_Qformer: B, T, C = speech_embeds.shape kernel = round(1500 * self.second_per_window / 30.0) stride = round(1500 * self.second_stride / 30.0) kernel = (1, kernel) stride = (1, stride) + # X: (B, T, C) -> (B, C, 1, T) speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2) + # X: (B, C, 1, T) -> (B, C*K, L) speech_embeds_overlap = F.unfold(speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride) + # L: number of windows _, _, L = speech_embeds_overlap.shape + # X: (B, C*K, L) → (B, C, K, L) speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L) + # X: (B, C, K, L) -> (B, L, K, C) speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1]) + # X: (B*L, K, C), e.g. [88, 17, 2048] speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C) + # M: (B*L, K), all ones (no masking) speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device) + # no_window: Q: (1, Q, C) -> (B, Q, C) + # window: Q: (1, Q, C) e.g. [1, 1, 768] -> (B*L, Q, C) query_tokens = self.speech_query_tokens.expand(speech_embeds.shape[0], -1, -1) + + # no_window: + # K: (B, T, C) = X (B, C, T) * W_K (C, C) + # V: (B, T, C) = X (B, C, T) * W_V (C, C) + # attn_score: (B, Q, T) = Q (B, Q, C) · Kᵀ (B, C, T) + # attn_prob: (B, Q, T) = softmax(attn_score, dim = -1) + # query_output: (B, Q, C) = attn_prob (B, Q, T) · V (B, T, C), e.g. [1, 1, 768] + # _________________________________________________________ + + # window: + # K: (B*L, K, C) = X * W_K + # V: (B*L, K, C) = X * W_V + # attn_score: (B*L, Q, K) = Q (B*L, Q, C) ⋅ Kᵀ (B*L, C, K) + # attn_prob: (B*L, Q, K) = softmax(attn_score, dim = -1) + # query_output: (B*L, Q, C) = attn_prob (B*L, Q, K) ⋅ V (B*L, K, C), e.g. [88, 1, 768] + # _________________________________________________________ query_output = self.speech_Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=speech_embeds, encoder_attention_mask=speech_atts, return_dict=True, ) - speech_embeds = self.speech_llama_proj(query_output.last_hidden_state) + # no_window, X: (B, Q, C) -> (B, Q, d), e.g. [1, 1, 5120] + # _________________________________________________________ + + # window, X: (B*L, Q, C) -> (B*L, Q, d), e.g. [88, 1, 5120] + speech_embeds = self.speech_llama_proj(query_output.last_hidden_state) if self.window_level_Qformer: + # X: (B*L, Q, d) → (B, L*Q, d), e.g. 
[1, 88, 5120] speech_embeds = speech_embeds.view(B, -1, speech_embeds.size(2)).contiguous() - + # M: (B, L*Q), all ones (no masking) speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long).to(speech_embeds.device) else: raise NotImplementedError @@ -262,6 +317,18 @@ def _encode_auditory_feature(self, speech_embeds, audio_embeds=None): return speech_embeds, speech_atts def encode_speech(self, spectrogram, raw_wav=None, audio_padding_mask=None): + """ + Encodes spectrogram and optional raw audio into a unified feature representation. + + Args: + spectrogram (B, T_spec, d_spec): Input spectrogram features. + raw_wav (B, T_audio): Raw audio waveform for BEATs encoding. + audio_padding_mask (B, T_audio): Padding mask for raw audio. + + Returns: + speech_embeds (B, L*Q, d): Encoded speech/audio features. + speech_atts (B, L*Q): Attention mask for encoded features. + """ with self.maybe_autocast(): speech_embeds = self.speech_encoder(spectrogram, return_dict=True).last_hidden_state @@ -273,6 +340,19 @@ def encode_speech(self, spectrogram, raw_wav=None, audio_padding_mask=None): return self._encode_auditory_feature(speech_embeds, audio_embeds=audio_embeds) def prompt_wrap(self, embeds, atts, prompt, multi_prompt=False): + """ + Wraps speech/audio embeddings with prompt embeddings (masked) + + Args: + embeds (B, L*Q, d): Speech/audio embeddings. + atts (B, L*Q): Attention mask for speech/audio embeddings. + prompt (str or list): Prompt(s) containing "" marker. + multi_prompt (bool): Whether to use multiple prompts (one per batch item). + + Returns: + wrapped_embeds (B, T, d): Concatenated prompt and speech/audio embeddings. + wrapped_atts (B, T): Concatenated attention masks. + """ if prompt: if multi_prompt: p_before = [] @@ -315,6 +395,25 @@ def prompt_wrap(self, embeds, atts, prompt, multi_prompt=False): return embeds, atts def forward(self, samples, verbose=False): + """ + Forward pass for the SALMONN model. + Concat the embeddings of BOS, speech, and text + Concat the attention masks and targets of BOS (masked), speech (masked), and text (unmasked except padding) + Feed the combined embeddings into the LLM to get the loss and other metrics + + Args: + samples (dict): A dictionary containing input samples with keys: + - "spectrogram": Spectrogram input (B, T_spec, D_spec) + - "raw_wav" (optional): Raw audio waveform (B, T_audio) + - "padding_mask" (optional): Mask for padded audio regions (B, T_audio) + - "task": List of task identifiers for each sample + - "Q" (optional): Questions for QA tasks + - "text": Target text to generate + verbose (bool, optional): Whether to print verbose information. Defaults to False.
+ + Returns: + dict: Model outputs containing loss and other metrics + """ # detect whether there are multi tasks in this batch task = list(set(samples["task"])) if len(task) > 1 or "QA" in task: @@ -329,18 +428,21 @@ else: prompt = random.choice(self.prompt_dict[samples["task"][0]]) - # use speech/audio encoder to encode speech/audio - spectrogram = samples["spectrogram"] - raw_wav = samples.get("raw_wav", None) - audio_padding_mask = samples.get("padding_mask", None) + # Extract inputs from samples + spectrogram = samples["spectrogram"] # (B, T_spec, D_spec) + raw_wav = samples.get("raw_wav", None) # (B, T_audio) + audio_padding_mask = samples.get("padding_mask", None) # (B, T_audio) + # Encode speech/audio into embeddings, speech_embeds: (B, L*Q, d), speech_atts: (B, L*Q) speech_embeds, speech_atts = self.encode_speech(spectrogram, raw_wav=raw_wav, audio_padding_mask=audio_padding_mask) # wrap speech_embeds with prompts + # speech_embeds: (B, T, d), speech_atts: (B, T), T = L*Q + T_prompt if self.prompt_dict: speech_embeds, speech_atts = self.prompt_wrap(speech_embeds, speech_atts, prompt, multi_prompt=self.multi_prompt) # prepare inputs for LLM + # append end_sym to each target string in samples["text"] text = [t + self.end_sym for t in samples["text"]] to_regress_tokens = self.llama_tokenizer( text, return_tensors="pt", padding="longest", truncation=True, max_length=self.max_txt_len, add_special_tokens=False ).to(spectrogram.device) + # text embeddings: (B, T_text, d) to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids) if not self.lora else self.llama_model.model.model.embed_tokens(to_regress_tokens.input_ids) + + # targets for text tokens: the original token_ids, with padded positions masked to -100: (B, T_text) targets = to_regress_tokens.input_ids.masked_fill( to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 ) + # targets for BOS+speech tokens, all masked with -100: (B, T_speech + 1) empty_targets = ( torch.ones( [speech_atts.shape[0], speech_atts.shape[1] + 1], dtype=torch.long @@ -363,15 +469,20 @@ def forward(self, samples, verbose=False): targets = torch.cat([empty_targets, targets], dim=1) batch_size = speech_embeds.shape[0] + # BOS token: (B, 1) bos = torch.ones( [batch_size, 1], dtype=to_regress_tokens.input_ids.dtype, device=to_regress_tokens.input_ids.device, ) * self.llama_tokenizer.bos_token_id + # BOS token embeddings: (B, 1, d) bos_embeds = self.llama_model.model.embed_tokens(bos) if not self.lora else self.llama_model.model.model.embed_tokens(bos) + # attention mask for BOS token: (B, 1) atts_bos = speech_atts[:, :1] + # combined inputs_embeds: (B, 1 + T_speech + T_txt, d) inputs_embeds = torch.cat([bos_embeds, speech_embeds, to_regress_embeds], dim=1) + # combined attention_mask: (B, 1 + T_speech + T_txt) attention_mask = torch.cat([atts_bos, speech_atts, to_regress_tokens.attention_mask], dim=1) # calulate loss @@ -397,7 +508,23 @@ def forward(self, samples, verbose=False): return {"loss": loss} - def generate(self, samples, generate_cfg, prompts=None): + def generate( + self, + samples: dict, + generate_cfg: dict, + prompts: Optional[Union[str, List[str]]] = None + ) -> List[str]: + """ + Generates text output from speech/audio input using the model. + + Args: + samples: Input samples with spectrogram (B, T_spec, d_spec), optional raw_wav (B, T_audio) + generate_cfg: Generation parameters (max_new_tokens, num_beams, temperature, etc.)
+ prompts: Optional prompt template(s) for generation + + Returns: + List of generated text outputs, one per batch item + """ batch_size = samples["spectrogram"].shape[0] spectrogram = samples["spectrogram"] @@ -441,6 +568,7 @@ def generate(self, samples, generate_cfg, prompts=None): @classmethod def from_config(cls, config): + """Creates a SALMONN model instance from a configuration dictionary.""" llama_path = config.get("llama_path") whisper_path = config.get("whisper_path") freeze_whisper = config.get("freeze_whisper", True) diff --git a/runner.py b/runner.py index 0ad16fc..bd54bfd 100755 --- a/runner.py +++ b/runner.py @@ -96,11 +96,26 @@ def unwrap_dist_model(self, model): return model def train_epoch(self, epoch): + """ + Trains the model for one epoch. + + Performs forward and backward passes, gradient updates, and tracks metrics. + Handles gradient accumulation and mixed precision training if enabled. + + Args: + epoch (int): Current epoch number + + Returns: + dict: Dictionary of averaged training metrics for the epoch + """ + # Set model to training mode self.model.train() + # Initialize metric tracking metric_logger = MetricLogger(delimiter=" ") metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}")) + # TODO: add meter for TFLOPs, samples per second, GBS, vRAM usage, etc. logging.info( "Start training epoch {}, {} iters per inner epoch.".format( @@ -109,23 +124,30 @@ def train_epoch(self, epoch): ) header = "Train: data epoch: [{}]".format(epoch) + # Main training loop for i in metric_logger.log_every(range(self.iters_per_epoch), self.config.config.run.log_freq, header=header, logger=self.log_writter, start_step=epoch*self.iters_per_epoch): if i >= self.iters_per_epoch: break + # Get batch and move to device samples = next(self.train_loader) samples = prepare_sample(samples, cuda_enabled=self.cuda_enabled) + # Update learning rate scheduler self.scheduler.step(cur_epoch=epoch, cur_step=i) + # Forward pass with optional mixed precision with torch.cuda.amp.autocast(enabled=self.use_amp): loss = self.model(samples)["loss"] + # Backward pass with optional mixed precision scaling if self.use_amp: self.scaler.scale(loss).backward() else: loss.backward() + # Update weights if gradient accumulation condition is met + # accum_grad_iters = GBS / (MBS * DP) if (i + 1) % self.config.config.run.accum_grad_iters == 0: if self.use_amp: self.scaler.step(self.optimizer) @@ -134,11 +156,15 @@ def train_epoch(self, epoch): self.optimizer.step() self.optimizer.zero_grad() + # Update metrics metric_logger.update(loss=loss.item()) metric_logger.update(lr=self.optimizer.param_groups[0]["lr"]) + # Synchronize metrics across processes in distributed training metric_logger.synchronize_between_processes() logging.info("Averaged stats: " + str(metric_logger.global_avg())) + + # Return formatted metrics return { k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items() @@ -225,20 +251,36 @@ def valid_epoch(self, epoch, split, decode=False, save_json=False): return ret def save_result(self, result, result_dir, filename): + """ + Saves evaluation results to JSON files, handling distributed processing. + + In distributed mode, each process saves its partial results to a rank-specific file. + The main process then aggregates all partial results into a single final file. + + The function handles UTF-8 encoding issues that may arise with non-ASCII text. 
+ + Args: + result (list): List of result dictionaries to save + result_dir (str): Directory to save results + filename (str): Base filename for the result files + """ result_file = os.path.join( result_dir, "%s_rank%d.json" % (filename, get_rank()) ) final_result_file = os.path.join(result_dir, "%s.json" % filename) + # Save this process's results to rank-specific file try: json.dump(result, open(result_file, "w"), ensure_ascii=False) except Exception as e: logging.warning(f"Error saving {result_file}. Error: {e}") json.dump(result, open(result_file, "w", encoding="utf-8"), ensure_ascii=False) + # Synchronize all processes before merging if is_dist_avail_and_initialized(): dist.barrier() + # Main process aggregates results from all ranks if is_main_process(): logging.info("rank %d starts merging results." % get_rank()) result = [] @@ -254,6 +296,7 @@ def save_result(self, result, result_dir, filename): res = json.load(open(result_file, "r", encoding="utf-8")) result += res + # Save merged results to final file try: json.dump(result, open(final_result_file, "w"), ensure_ascii=False) except Exception as e: @@ -279,6 +322,8 @@ def train(self): # validating phase logging.info("Validating Phase") valid_log = self.valid_epoch(cur_epoch, "valid", decode=False, save_json=False) + # if online validation is used, will also save a copy of the current best checkpoint + # this current best checkpoint is always updated as training goes. if valid_log is not None: if is_main_process(): agg_metrics = valid_log["agg_metrics"] @@ -290,7 +335,7 @@ def train(self): valid_log.update({"best_epoch": best_epoch}) self.log_stats(valid_log, split_name="valid") - + # always save the regular checkpoint at the end of each epoch self.save_checkpoint(cur_epoch, is_best=False) if self.use_distributed:
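A closing note on the batch-size shorthand used in the config and runner comments above (`accum_grad_iters = GBS / (MBS * DP)`): it is plain bookkeeping between the global batch size (GBS), the per-GPU micro batch size (MBS, i.e. `batch_size_train`), and the data-parallel world size (DP). The sketch below uses purely hypothetical numbers; none of them are defaults of this repo.

```python
# Hypothetical illustration of GBS = MBS * DP * accum_grad_iters.
def required_accum_grad_iters(global_batch_size: int, micro_batch_size: int, data_parallel_size: int) -> int:
    per_optimizer_step = micro_batch_size * data_parallel_size
    assert global_batch_size % per_optimizer_step == 0, "GBS must be divisible by MBS * DP"
    return global_batch_size // per_optimizer_step

# e.g. batch_size_train=8 (MBS) on 4 nodes x 8 GPUs (DP=32):
print(required_accum_grad_iters(256, 8, 32))  # -> 1 (no gradient accumulation needed)
print(required_accum_grad_iters(512, 8, 32))  # -> 2 (accumulate two micro-batches per update)
```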