1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
__pycache__/
21 changes: 21 additions & 0 deletions README.md
@@ -45,6 +45,27 @@ In order to access each model:

Once you run those steps, the commands below to run the LayerSkip checkpoints should work.

## Fine Tune

To train any supported HuggingFace model with the LayerSkip approach:
```bash
torchrun finetune_layerskip.py \
--ckpt facebook/llama2-7B \
--ds_ckpt some_dataset \
--template "###INST: {utterance}\n\n###RES: {semantic_parse}" \
--lr 1e-4 \
--batch_size 8 \
--epochs 3 \
--early_exit_loss_scale 1.0 \
--eval_freq 50 \
--output_dir ./checkpoints
```
Collaborator commented:

Updating the command here to use train.py. But could you also check that the command works? We could include a command that users can copy and paste into their terminal (e.g., use the TopV2 dataset rather than some_dataset).

Suggested change (replace the "## Fine Tune" section above with):

## Train
To train any supported HuggingFace model with the LayerSkip approach:
```bash
torchrun train.py \
    --ckpt meta-llama/Llama-2-7b-hf \
    --ds_ckpt some_dataset \
    --template "###INST: {utterance}\n\n###RES: {semantic_parse}" \
    --lr 1e-4 \
    --batch_size 8 \
    --epochs 3 \
    --early_exit_loss_scale 1.0 \
    --eval_freq 50 \
    --output_dir ./checkpoints
```

Author replied:
I will check and update.


Tips:
- **Adjust your hyperparameters**: Play with `--early_exit_loss_scale` to increase or decrease the importance of the early-exit path (the loss weighting is sketched below).
- **Monitor speed & memory usage**: Even during training, the choice of early-exit layer can affect throughput and memory footprint, so keep an eye on both.
- **Run `python finetune_layerskip.py --help`** to see all command-line options.
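
For reference, the script combines the final-layer loss and the early-exit loss as a weighted average. A sketch of the weighting, with variable names following `fine-tune.py`:

```python
# orig_loss:  cross-entropy loss of the final-layer logits
# exit_loss:  cross-entropy loss of the early-exit logits
# early_exit_loss_scale: value passed via --early_exit_loss_scale
total_scale = 1.0 + early_exit_loss_scale
total_loss = (1.0 * orig_loss + early_exit_loss_scale * exit_loss) / total_scale
```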

## Generate

To run one of our models in interactive mode using regular autoregressive decoding:
204 changes: 204 additions & 0 deletions fine-tune.py
@@ -0,0 +1,204 @@
from dataclasses import dataclass
from functools import partial

import torch
from torch import nn
from torch.nn import functional
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from tqdm import tqdm

import os

@dataclass
class FineTuneArguments:
ckpt: str = "meta-llama/Llama-2-7b-hf"  # <--- The Llama 2 tokenizer supports add_bos_token/add_eos_token; the Llama 3.2 tokenizer does not (be careful)
ds_ckpt: str = "WillHeld/top_v2"
template: str = "### Instruction: {utterance}\n ### Response: {semantic_parse}"
lr: float = 2e-5
batch_size: int = 8
epochs: int = 1
eval_freq: int = 5000
early_exit_loss_scale: float = 1.0
save_steps: int = 5000
output_dir: str = "./checkpoints/"
hub_id: str = None

# Parse command-line flags (e.g., --ckpt, --lr) into FineTuneArguments
parser = HfArgumentParser(FineTuneArguments)
args = parser.parse_args_into_dataclasses()[0]

class LayerSkipModel(nn.Module):
def __init__(self, model, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model = model
self.num_layers = model.config.num_hidden_layers
self.early_exit_layer = 0
Collaborator commented:
We don't have to do it in this PR, but in the future, self.early_exit_layer could be a list of layers

Author replied:

This is an interesting idea. If I am understanding correctly, you mean we could have a list of indices for the layers we want to exit from early, e.g. `self.early_exit_layers = [0, 4, 8]`?

Collaborator replied:
Yes. That is actually what we referred to as "rotational curriculum" in the paper.

Author replied:
I can create an issue and then do it in another PR if you want.
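
For illustration, a minimal sketch of what rotating through a list of exit layers could look like (hypothetical; not part of this PR, and the layer indices are made up):

```python
import itertools

# Hypothetical rotational curriculum: cycle through several candidate exit layers,
# advancing to the next one on each training step.
early_exit_layers = itertools.cycle([4, 8, 12])  # made-up layer indices

def next_exit_layer() -> int:
    # Returns 4, 8, 12, 4, 8, 12, ... on successive calls
    return next(early_exit_layers)
```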


def forward(self, input_ids, attention_mask):
# If there are N layers, there are N+1 hidden states [l=0, l=N]
# The zeroth hidden state (l=0) is the output of the embedding layer
# The last hidden state (l=N) is the normalized output of the final layer
# We need to early exit from layers [l=1, l=N-1] both inclusive
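# e.g., with num_layers = 32 the exit layer cycles 1, 2, ..., 31, then wraps back to 1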
self.early_exit_layer = (self.early_exit_layer % (self.num_layers - 1)) + 1

# Get the output logits and hidden states
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
logits = outputs["logits"]
hidden_states = outputs["hidden_states"]

# Select the exit hidden state and normalize it
exit_state = hidden_states[self.early_exit_layer]
exit_state = self.model.model.norm(exit_state)
exit_logits = self.model.lm_head(exit_state)

return logits, exit_logits

def collate_fn(batch, template):
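# Example with the default template and a hypothetical TopV2-style sample (made-up values):
#   collate_fn([{"utterance": "set an alarm", "semantic_parse": "[IN:CREATE_ALARM ...]"}], template)
#   -> ["### Instruction: set an alarm\n ### Response: [IN:CREATE_ALARM ...]"]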
return [template.format(**sample) for sample in batch]


def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args):
global_step = 0
for epoch in range(args.epochs):
trainer.train()
for step, batch in tqdm(enumerate(train_dl), total=len(train_dl)):
inputs = tokenizer(
batch,
return_tensors="pt",
padding=True,
)

input_ids = inputs["input_ids"]
input_attn_mask = inputs["attention_mask"]
labels = inputs["input_ids"].masked_fill(~input_attn_mask.bool(), -100)

input_ids = input_ids.to(device)
input_attn_mask = input_attn_mask.to(device)
labels = labels.to(device)

logits, exit_logits = trainer(
input_ids=input_ids, attention_mask=input_attn_mask
)
orig_loss = trainer.model.loss_function(
logits=logits, labels=labels, vocab_size=trainer.model.vocab_size,
)
exit_loss = trainer.model.loss_function(
logits=exit_logits, labels=labels, vocab_size=trainer.model.vocab_size,
)
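# Combine the two losses as a weighted average; --early_exit_loss_scale
# controls the relative weight of the early-exit path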
total_scale = 1.0 + args.early_exit_loss_scale
total_loss = (
1.0 * orig_loss + args.early_exit_loss_scale * exit_loss
) / total_scale

optimizer.zero_grad()
total_loss.backward()
optimizer.step()
global_step += 1

if global_step % args.eval_freq == 0:
trainer.eval()
eval_loss = 0.0
num_val_steps = 0
with torch.no_grad():
for val_batch in tqdm(val_dl, total=len(val_dl)):
inputs = tokenizer(val_batch, return_tensors="pt", padding=True)

input_ids = inputs["input_ids"]
input_attn_mask = inputs["attention_mask"]
labels = inputs["input_ids"].masked_fill(~input_attn_mask.bool(), -100)

input_ids = input_ids.to(device)
input_attn_mask = input_attn_mask.to(device)
labels = labels.to(device)

logits, exit_logits = trainer(
input_ids=input_ids, attention_mask=input_attn_mask
)
orig_loss = trainer.model.loss_function(
logits=logits,
labels=labels,
vocab_size=trainer.model.vocab_size,
)
exit_loss = trainer.model.loss_function(
logits=exit_logits,
labels=labels,
vocab_size=trainer.model.vocab_size,
)
total_scale = 1.0 + args.early_exit_loss_scale
loss = (
1.0 * orig_loss + args.early_exit_loss_scale * exit_loss
) / total_scale

eval_loss += loss.item()
num_val_steps += 1

print(
f"Epoch {epoch}, Step {global_step}: "
f"Train Loss: {total_loss.item():.4f}, "
f"Val Loss: {eval_loss / num_val_steps:.4f}"
)
trainer.train()

if global_step % args.save_steps == 0:
step_dir = f"{args.output_dir}/step_{global_step}"
os.makedirs(step_dir, exist_ok=True)

model_path = f"{step_dir}/model"
trainer.model.save_pretrained(model_path)
print(f"Saved pretrained model to {model_path}")

checkpoint_path = f"{step_dir}/checkpoint.pt"
torch.save(
{
"step": global_step,
"epoch": epoch,
"model_state_dict": trainer.model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
},
checkpoint_path,
)
print(f"Saved training checkpoint to {checkpoint_path}")

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(args.ckpt)
tokenizer.pad_token = tokenizer.eos_token

# This is only true for Llama 2 models (check the ckpt arg above)
tokenizer.add_bos_token = True # This defaults to True
tokenizer.add_eos_token = True # This defaults to False, setting it to True will add eos token to each sample

model = AutoModelForCausalLM.from_pretrained(args.ckpt, torch_dtype="bfloat16", device_map="auto")

trainer = LayerSkipModel(model=model)
optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

train_ds = load_dataset(args.ds_ckpt, split="train")
val_ds = load_dataset(args.ds_ckpt, split="eval")

collate_fn_with_template = partial(collate_fn, template=args.template)
train_dl = DataLoader(
train_ds,
batch_size=args.batch_size,
collate_fn=collate_fn_with_template,
shuffle=True,
)
val_dl = DataLoader(
val_ds,
batch_size=args.batch_size,
collate_fn=collate_fn_with_template,
)

trainer.to(device)
trainer.train()

train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args)

if args.hub_id:
tokenizer.push_to_hub(args.hub_id)
trainer.model.push_to_hub(args.hub_id)