added Documentations #101

Open · wants to merge 9 commits into main
41 changes: 23 additions & 18 deletions configs/config.yaml
@@ -14,9 +14,9 @@

model:
# paths
llama_path: "path to llama model"
whisper_path: "path to whisper model"
beats_path: "path to beats model"
llama_path: "path to llama model" # huggingface repo_id or local path
whisper_path: "path to whisper model" # huggingface repo_id or local path
beats_path: "path to beats model" # local path to the BEATs .pt checkpoint downloaded from its official repo

ckpt: "model ckpt" # if not "", load model from ckpt for training or evaluation

@@ -27,7 +27,11 @@ model:
use_speech_Qformer: True
freeze_speech_QFormer: False
window_level_Qformer: True
num_speech_query_token: 1
num_speech_query_token: 1 # number of learnable Q-Former query tokens (per window when window_level_Qformer is True)
# Given an audio of length L seconds, number of windows = floor((L - second_per_window) / second_stride) + 1
# With the current values (1/3 second), each second produces 3 non-overlapping windows
# Numeric example:
# - For a 10-second audio: number of windows = floor((10 - 0.333333) / 0.333333) + 1 = 29 + 1 = 30 windows
second_per_window: 0.333333
second_stride: 0.333333

@@ -42,9 +46,9 @@ model:

multi_prompt: True
prompt_template: "USER: {}\nASSISTANT:"
prompt_path: "prompts/train_prompt.json"
test_prompt_path: "prompts/test_prompt.json"
max_txt_len: 300
prompt_path: "prompts/train_prompt.json" # local path to train prompt
test_prompt_path: "prompts/test_prompt.json" # local path to test prompt
max_txt_len: 300 # text token limit, beyond which the text will be truncated
end_sym: "</s>"

datasets:
@@ -56,23 +60,24 @@ datasets:

run:
# log & settings
seed: 42
output_dir: "output directory"
seed: 42 # random seed
output_dir: "output directory" # directory to save checkpoints & tensorboard logs, etc.
evaluate: False # if True, only evaluate model on test data

log_freq: 5
epoch_based: False
iters_per_epoch: 3000
accum_grad_iters: 1
batch_size_train: 8
batch_size_eval: 8
num_workers: 8
save_freq: 1000 # save checkpoint every 1000 iterations
epoch_based: False # if True, exhaust all the data in each epoch
iters_per_epoch: 3000 # number of iterations per epoch, set manually when epoch_based is False
accum_grad_iters: 1 # gradient accumulation steps = global batch size / (micro batch size * data-parallel size)
batch_size_train: 8 # micro batch size (per GPU) for training
batch_size_eval: 8 # micro batch size (per GPU) for evaluation
num_workers: 8 # number of dataloader workers

device: "cuda"
use_distributed: True
amp: True
world_size: 1
dist_url: "env://"
amp: True # automatic mixed precision by torch.cuda.amp
world_size: 1 # overwritten by actual world_size (n_nodes * n_gpus)
dist_url: "env://" # use environment variables (MASTER_ADDR, MASTER_PORT) for distributed setup

# optimizer & scheduler
optims:
2 changes: 2 additions & 0 deletions dist_utils.py
@@ -55,10 +55,12 @@ def is_main_process():


def init_distributed_mode(args):
# if using torchrun
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
# if using srun in SLURM
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
226 changes: 226 additions & 0 deletions models/README.md
@@ -0,0 +1,226 @@
# SALMONN Models

This directory contains the implementation of the models used in the SALMONN (Speech Audio Language Music Open Neural Network) project, including Whisper, BEATs, Qformer, and LLaMA.

## 1. Whisper Architecture

Whisper is a speech processing model used in SALMONN for encoding speech inputs. It's a Transformer-based encoder-decoder model originally designed for automatic speech recognition (ASR) and translation, but in SALMONN, primarily the encoder component is utilized to extract rich representations from speech.

### 1.1 Class Diagram

#### Class Inheritance Hierarchy
```
PreTrainedModel
└──> WhisperPreTrainedModel
└──> WhisperModel
└──> WhisperForConditionalGeneration
```

#### Component Composition Structure
```
WhisperModel
└──> WhisperEncoder
├──> WhisperEncoderLayer (multiple layers)
│ │
│ ├──> WhisperAttention
│ │ │
│ │ ├──> Linear projections (q_proj, k_proj, v_proj)
│ │ └──> Linear projection (out_proj)
│ │
│ ├──> LayerNorm (self_attn_layer_norm)
│ │
│ ├──> Feed-Forward Network
│ │ │
│ │ ├──> Linear (fc1)
│ │ ├──> Activation function (GELU)
│ │ └──> Linear (fc2)
│ │
│ └──> LayerNorm (final_layer_norm)
└──> LayerNorm (layer_norm)
```
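
As a rough illustration of how only the encoder half is used, the sketch below loads a Whisper checkpoint through the Hugging Face `transformers` API and extracts frame-level encoder features from a waveform. The checkpoint name and the 10-second dummy input are illustrative assumptions, not taken from the SALMONN code.

```python
# Minimal sketch: extract Whisper encoder features (illustrative, not the SALMONN code path)
import torch
from transformers import WhisperFeatureExtractor, WhisperModel

model_name = "openai/whisper-large-v2"  # assumption: any Whisper checkpoint works the same way
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
encoder = WhisperModel.from_pretrained(model_name).encoder  # only the encoder is needed

waveform = torch.randn(16000 * 10)  # 10 s of dummy 16 kHz audio
inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    # log-mel spectrogram -> (batch, n_frames, d_model) hidden states
    speech_embeds = encoder(inputs.input_features).last_hidden_state
print(speech_embeds.shape)  # e.g. torch.Size([1, 1500, 1280]) for large-v2
```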

## 2. BEATs Architecture

BEATs (Bidirectional Encoder representation from Audio Transformers) is an audio processing model used in SALMONN for encoding general audio inputs. It's a transformer-based model designed to extract rich representations from audio spectrograms, complementing Whisper's speech-focused features.

### 2.1 Class Diagram

#### Class Inheritance Hierarchy
```
nn.Module
└──> BEATs
```

#### Component Composition Structure
```
BEATs
├──> nn.Conv2d (patch_embedding)
├──> nn.Linear (post_extract_proj)
├──> LayerNorm (layer_norm)
└──> TransformerEncoder
└──> ModuleList of TransformerEncoderLayer
├──> MultiheadAttention
│ │
│ └──> Linear projections (q_proj, k_proj, v_proj, out_proj)
├──> LayerNorm (self_attn_layer_norm)
├──> FeedForward
│ │
│ ├──> Linear (fc1)
│ ├──> Activation function (GELU)
│ ├──> Dropout
│ └──> Linear (fc2)
└──> LayerNorm (final_layer_norm)
```
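
A hedged sketch of loading and running a BEATs checkpoint, following the usage shown in the official BEATs repository; the import path and checkpoint filename are placeholders and may differ in this project.

```python
# Minimal sketch of BEATs feature extraction (paths and import location are assumptions)
import torch
from BEATs import BEATs, BEATsConfig  # adjust to wherever BEATs.py lives in this repo

checkpoint = torch.load("BEATs_iter3_plus_AS2M.pt")  # the .pt file referenced by beats_path
cfg = BEATsConfig(checkpoint["cfg"])
beats = BEATs(cfg)
beats.load_state_dict(checkpoint["model"])
beats.eval()

waveform = torch.randn(1, 16000 * 10)             # (batch, samples) at 16 kHz
padding_mask = torch.zeros(1, 16000 * 10).bool()  # True marks padded samples

with torch.no_grad():
    # frame-level audio representations complementing Whisper's speech features
    beats_embeds = beats.extract_features(waveform, padding_mask=padding_mask)[0]
```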
### 2.2 Use in SALMONN

In the SALMONN architecture, BEATs serves as a secondary audio encoder that:
1. Processes raw audio waveforms into audio features that complement Whisper's speech features
2. Provides additional context for non-speech audio elements like music, environmental sounds, and acoustic events

These features are then combined with Whisper features before being processed by the Qformer, enhancing SALMONN's ability to understand the full audio context.

In SALMONN, BEATs processes audio inputs in parallel with Whisper:
```python
beats_embeds = self.audio_encoder(waveform, return_dict=True)[0]
```

The BEATs features can then be combined with Whisper features to provide a more comprehensive audio understanding:
```python
if self.use_audio_Qformer:
speech_embeds = torch.cat([speech_embeds, beats_embeds], dim=-1)
speech_embeds = self.speech_beats_proj(speech_embeds)
```

This multi-encoder approach allows SALMONN to leverage both speech-specific features from Whisper and general audio features from BEATs, creating a more robust representation of the audio input.
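
Note that concatenating along the feature dimension (`dim=-1`) only works when the two feature sequences have the same number of frames, and the two encoders generally emit different frame rates. A minimal sketch of one way to handle this, with hypothetical tensor names and shapes rather than the actual SALMONN variables:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: Whisper and BEATs usually produce different numbers of frames
speech_embeds = torch.randn(1, 1500, 1280)   # (batch, T_whisper, d_whisper)
beats_embeds = torch.randn(1, 496, 768)      # (batch, T_beats, d_beats)

# Pad BEATs frames along the time dimension so both sequences are T_whisper long
pad_len = speech_embeds.size(1) - beats_embeds.size(1)
beats_embeds = F.pad(beats_embeds, (0, 0, 0, pad_len))

# Concatenate along the feature dimension and project to the Q-Former input size
fused = torch.cat([speech_embeds, beats_embeds], dim=-1)   # (1, 1500, 2048)
speech_beats_proj = torch.nn.Linear(1280 + 768, 1280)      # assumed projection layer
fused = speech_beats_proj(fused)
```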


## 3. Qformer Architecture

The Qformer (Query Transformer) is a key component in SALMONN that serves as a modality adapter between speech/audio inputs and the language model. It's based on a modified BERT architecture and is designed to efficiently extract and transform features from one modality to be compatible with another.

Qformer uses learnable query tokens that interact with input features through cross-attention mechanisms. These query tokens act as information bottlenecks that extract the most relevant information from the input modality. The number of query tokens is typically much smaller than the input sequence length, allowing for efficient information extraction and dimensionality reduction.

### 3.1 Class Diagram

#### Class Inheritance Hierarchy
```
PreTrainedModel
└──> BertPreTrainedModel
├──> BertModel
├──> BertLMHeadModel (Qformer)
└──> BertForMaskedLM
```

#### Component Composition Structure
```
Composition Structure (has-a relationships):

BertLMHeadModel
├──> BertModel
│ │
│ ├──> BertEmbeddings
│ │
│ ├──> BertEncoder
│ │ │
│ │ └──> a stack of BertLayer, each is a transformer block containing
│ │ │
│ │ ├──> BertAttention
│ │ │ │
│ │ │ └──> BertSelfAttention (can be made into both self/cross-attention)
│ │ │
│ │ ├──> BertIntermediate
│ │ │
│ │ └──> BertOutput
│ │
│ └──> BertPooler
└──> ClassificationHead (cls)
```

### 3.2 Use in SALMONN

The Qformer architecture consists of:

- **Query Tokens**: Learnable parameters that serve as the interface between modalities
- **Self-Attention Layers**: Standard BERT-style self-attention for processing query tokens
- **Cross-Attention Layers**: Allow query tokens to attend to input features from another modality

The rest of the BERT architecture is not used in the Qformer:
- **Word/Position Embeddings**: Removed
- **Layer Outputs/Intermediates (FFN)**: Removed
- **BertPooler**: Not used
- **CLS head**: Removed

The Qformer can operate at different levels (a minimal sketch follows this list):
- **Window-level**: Processing fixed-length windows of speech/audio features
- **Sequence-level**: Processing the entire sequence at once
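
To make the query-token / cross-attention flow concrete, below is a minimal window-level sketch. It uses a stock `transformers` BERT decoder with cross-attention as a stand-in for the modified Qformer, and the window length and feature sizes are illustrative assumptions rather than values from the SALMONN code.

```python
import torch
from transformers import BertConfig, BertModel

# Stand-in Qformer: a small BERT with cross-attention enabled.
# (The real Qformer is a modified BertLMHeadModel with a query_embeds interface;
#  stock classes are used here so the sketch runs as-is.)
hidden = 768
qformer = BertModel(BertConfig(
    hidden_size=hidden, num_hidden_layers=2, num_attention_heads=12,
    is_decoder=True, add_cross_attention=True,
))

B, T, C = 1, 1496, hidden     # speech/audio features, assumed already projected to `hidden`
window_len = 17               # frames per ~0.333 s window (illustrative value)
n_windows = T // window_len   # 88 windows

# Window-level: fold the windows into the batch dimension so each window is encoded independently
speech_embeds = torch.randn(B, n_windows * window_len, C)
windows = speech_embeds.view(B * n_windows, window_len, C)
window_atts = torch.ones(windows.shape[:-1], dtype=torch.long)

# One learnable query token per window (num_speech_query_token: 1 in config.yaml)
query_tokens = torch.nn.Parameter(torch.zeros(1, 1, hidden))
queries = query_tokens.expand(windows.shape[0], -1, -1)

# Queries self-attend and cross-attend to the window features; each window is
# compressed into a single output vector
out = qformer(
    inputs_embeds=queries,
    encoder_hidden_states=windows,
    encoder_attention_mask=window_atts,
    return_dict=True,
)
window_embeds = out.last_hidden_state.view(B, n_windows, hidden)  # (1, 88, 768)
```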

## 4. LLaMA Architecture

LLaMA (Large Language Model Meta AI) is the foundation language model used in SALMONN for text generation. It's a decoder-only transformer architecture optimized for efficient inference and strong language understanding capabilities.

### 4.1 Class Diagram

#### Class Inheritance Hierarchy
```
PreTrainedModel
└──> LlamaPreTrainedModel
└──> LlamaForCausalLM
```

#### Component Composition Structure
```
LlamaForCausalLM
├──> LlamaModel
│ │
│ ├──> nn.Embedding (embed_tokens)
│ │
│ ├──> nn.ModuleList of LlamaDecoderLayer
│ │ │
│ │ ├──> LlamaAttention
│ │ │ │
│ │ │ ├──> Linear projections (q_proj, k_proj, v_proj, o_proj)
│ │ │ └──> LlamaRotaryEmbedding (rotary_emb)
│ │ │
│ │ ├──> LlamaRMSNorm (input_layernorm)
│ │ │
│ │ ├──> LlamaMLP
│ │ │ │
│ │ │ ├──> Linear (gate_proj)
│ │ │ ├──> Linear (up_proj)
│ │ │ ├──> ACT2FN[hidden_act] activation
│ │ │ └──> Linear (down_proj)
│ │ │
│ │ └──> LlamaRMSNorm (post_attention_layernorm)
│ │
│ └──> LlamaRMSNorm (norm)
└──> nn.Linear (lm_head)
```

In SALMONN, LLaMA is integrated with speech features through a projection layer that maps Qformer outputs to the LLaMA embedding space. The model processes inputs in this sequence (a sketch follows the list):

1. BOS token embeddings
2. Projected speech embeddings from Qformer
3. Text token embeddings (for training)
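
A hedged sketch of that assembly, using stock `transformers` LLaMA classes with tiny random weights and made-up shapes (the prompt wrapping and the exact projection module in SALMONN may differ):

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny randomly initialised LLaMA so the sketch runs without downloading weights
config = LlamaConfig(hidden_size=256, intermediate_size=512,
                     num_hidden_layers=2, num_attention_heads=4, vocab_size=32000)
llama = LlamaForCausalLM(config)
embed_tokens = llama.get_input_embeddings()

B = 1
qformer_out = torch.randn(B, 88, 768)                         # Qformer outputs (hypothetical)
speech_llama_proj = torch.nn.Linear(768, config.hidden_size)  # assumed projection layer
speech_embeds = speech_llama_proj(qformer_out)                # map into LLaMA embedding space

bos = torch.full((B, 1), config.bos_token_id, dtype=torch.long)
text_ids = torch.randint(0, config.vocab_size, (B, 12))       # stand-in for tokenized text

# 1. BOS embeddings  2. projected speech embeddings  3. text token embeddings
inputs_embeds = torch.cat([embed_tokens(bos), speech_embeds, embed_tokens(text_ids)], dim=1)
attention_mask = torch.ones(inputs_embeds.shape[:-1], dtype=torch.long)

outputs = llama(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
print(outputs.logits.shape)   # (B, 1 + 88 + 12, vocab_size)
```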