From 44dd069ea33c450fd5aff415b666ef55a4701311 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Fri, 17 Jan 2025 23:17:28 +0530 Subject: [PATCH 1/8] chore: init fine tune --- fine-tune.py | 129 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 fine-tune.py diff --git a/fine-tune.py b/fine-tune.py new file mode 100644 index 0000000..db76e78 --- /dev/null +++ b/fine-tune.py @@ -0,0 +1,129 @@ +import torch +from torch import nn +from torch.utils.data import DataLoader +from datasets import load_dataset +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + + +class LayerSkipModel(nn.Module): + def __init__(self, model, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model = model + self.num_layers = model.config.num_hidden_layers + self.early_exit_layer = 0 + + def forward(self, input_ids, attention_mask): + # If there are N layers, there are N+1 hidden states [l=0, l=N] + # The zero th hidden state (l=0) is input to the embedding layer + # The last hidden state (l=N) is the normalized output of the final layer + # We need to early exit from layers [l=1, l=N-1] both inclusive + self.early_exit_layer = (self.early_exit_layer % (self.num_layers - 1)) + 1 + + # Get the output logits and hidden states + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + logits = outputs["logits"] + hidden_states = outputs["hidden_states"] + + # Select the exit hidden state and normalize it + exit_state = hidden_states[self.early_exit_layer] + exit_state = self.model.model.norm(exit_state) + exit_logits = self.model.lm_head(exit_state) + + return logits, exit_logits + + +def collate_fn(batch): + formatted_batch = [ + f"###INST: {sample['utterance']}\n\n###RES: {sample['semantic_parse']}" + for sample in batch + ] + return formatted_batch + + +if __name__ == "__main__": + ckpt = "meta-llama/Llama-3.2-1B" + ds_ckpt = "WillHeld/top_v2" + lr = 1e-3 + batch_size = 8 + epochs = 1 + device = "cuda" + + tokenizer = AutoTokenizer.from_pretrained(ckpt) + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained(ckpt) + + trainer = LayerSkipModel(model=model) + optimizer = torch.optim.Adam(params=model.parameters(), lr=lr) + + train_ds = load_dataset(ds_ckpt, split="train") + val_ds = load_dataset(ds_ckpt, split="eval") + + train_dl = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn) + val_dl = DataLoader(val_ds, batch_size=batch_size, collate_fn=collate_fn) + + trainer.to(device) + trainer.train() + + for idx in range(epochs): + for idx, batch in enumerate(train_dl): + inputs = tokenizer(batch, return_tensors="pt", padding=True) + + input_ids = inputs["input_ids"][:, :-1].to(device) + input_attn_mask = inputs["attention_mask"][:, :-1].to(device) + + labels = inputs["input_ids"][:, 1:].to(device) + + logits, exit_logits = trainer( + input_ids=input_ids, attention_mask=input_attn_mask + ) + orig_loss = trainer.model.loss_function( + logits=logits, labels=labels, vocab_size=trainer.model.vocab_size + ) + exit_loss = trainer.model.loss_function( + logits=exit_logits, labels=labels, vocab_size=trainer.model.vocab_size + ) + loss = orig_loss + exit_loss + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if idx % 100 == 0: + eval_loss = 0 + trainer.eval() + with torch.no_grad(): + for val_idx, val_batch in enumerate(val_dl): + inputs = tokenizer(val_batch, return_tensors="pt", padding=True) + + input_ids = 
inputs["input_ids"][:, :-1].to(device) + input_attn_mask = inputs["attention_mask"][:, :-1].to(device) + + labels = inputs["input_ids"][:, 1:].to(device) + + logits, exit_logits = trainer( + input_ids=input_ids, attention_mask=input_attn_mask + ) + orig_loss = trainer.model.loss_function( + logits=logits, + labels=labels, + vocab_size=trainer.model.vocab_size, + ) + exit_loss = trainer.model.loss_function( + logits=exit_logits, + labels=labels, + vocab_size=trainer.model.vocab_size, + ) + loss = orig_loss + exit_loss + + eval_loss += loss.item() + + print( + f"Epoch: {idx}, Train Loss: {loss.item():0.2f} Val Loss: {eval_loss / (val_idx - 1):0.2f}" + ) + trainer.train() From 6bbd682d049affe2d93927c64fe3d6a65502de1b Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Mon, 20 Jan 2025 10:34:58 +0530 Subject: [PATCH 2/8] chore: review suggestions --- .gitignore | 1 + README.md | 21 +++++++ fine-tune.py | 158 ++++++++++++++++++++++++++++++++++++--------------- 3 files changed, 135 insertions(+), 45 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ba0430d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ \ No newline at end of file diff --git a/README.md b/README.md index 0dd9308..7bb950b 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,27 @@ In order to access each model: Once you run those steps, the commands below to run the LayerSkip checkpoints should work. +## Fine Tune + +To train any supported HuggingFace model with the LayerSkip approach: +```bash +torchrun finetune_layerskip.py \ + --ckpt facebook/llama2-7B \ + --ds_ckpt some_dataset \ + --template "###INST: {utterance}\n\n###RES: {semantic_parse}" \ + --lr 1e-4 \ + --batch_size 8 \ + --epochs 3 \ + --early_exit_loss_scale 1.0 \ + --eval_freq 50 \ + --output_dir ./checkpoints +``` + +Tips: +- **Adjust your hyperparameters**: Play with `--early_exit_loss_scale` to increase or decrease the importance of the early-exit path. +- **Monitor speed & memory usage**: Even during training, certain intermediate layers might lead to more efficient training if configured properly. +- **Run `python finetune_layerskip.py --help`** to see all command-line options. 
+ ## Generate To run one of our models in interactive mode using regular autoregressive decoding: diff --git a/fine-tune.py b/fine-tune.py index db76e78..a42a428 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -1,9 +1,31 @@ +import logging +from dataclasses import dataclass +from functools import partial + import torch from torch import nn +from torch.nn import functional from torch.utils.data import DataLoader + from datasets import load_dataset -from transformers import AutoModelForCausalLM -from transformers import AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@dataclass +class FineTuneArguments: + ckpt: str = "meta-llama/Llama-3.2-1B" + ds_ckpt: str = "WillHeld/top_v2" + template: str = "###INST: {utterance}\n\n###RES: {semantic_parse}" + lr: float = 1e-3 + batch_size: int = 8 + epochs: int = 1 + eval_freq: int = 100 + early_exit_loss_scale: float = 1.0 + save_steps: int = 500 + output_dir: str = "./checkpoints" class LayerSkipModel(nn.Module): @@ -37,43 +59,18 @@ def forward(self, input_ids, attention_mask): return logits, exit_logits -def collate_fn(batch): - formatted_batch = [ - f"###INST: {sample['utterance']}\n\n###RES: {sample['semantic_parse']}" - for sample in batch - ] - return formatted_batch - - -if __name__ == "__main__": - ckpt = "meta-llama/Llama-3.2-1B" - ds_ckpt = "WillHeld/top_v2" - lr = 1e-3 - batch_size = 8 - epochs = 1 - device = "cuda" - - tokenizer = AutoTokenizer.from_pretrained(ckpt) - tokenizer.pad_token = tokenizer.eos_token - - model = AutoModelForCausalLM.from_pretrained(ckpt) - - trainer = LayerSkipModel(model=model) - optimizer = torch.optim.Adam(params=model.parameters(), lr=lr) +def collate_fn(batch, template): + return [template.format(**sample) for sample in batch] - train_ds = load_dataset(ds_ckpt, split="train") - val_ds = load_dataset(ds_ckpt, split="eval") - - train_dl = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn) - val_dl = DataLoader(val_ds, batch_size=batch_size, collate_fn=collate_fn) - - trainer.to(device) - trainer.train() - - for idx in range(epochs): - for idx, batch in enumerate(train_dl): - inputs = tokenizer(batch, return_tensors="pt", padding=True) +def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args): + global_step = 0 + for epoch in range(args.epochs): + trainer.train() + for step, batch in enumerate(train_dl): + inputs = tokenizer( + batch, return_tensors="pt", padding=True, truncation=True + ) input_ids = inputs["input_ids"][:, :-1].to(device) input_attn_mask = inputs["attention_mask"][:, :-1].to(device) @@ -88,17 +85,22 @@ def collate_fn(batch): exit_loss = trainer.model.loss_function( logits=exit_logits, labels=labels, vocab_size=trainer.model.vocab_size ) - loss = orig_loss + exit_loss + total_scale = 1.0 + args.early_exit_loss_scale + total_loss = ( + 1.0 * orig_loss + args.early_exit_loss_scale * exit_loss + ) / total_scale optimizer.zero_grad() - loss.backward() + total_loss.backward() optimizer.step() + global_step += 1 - if idx % 100 == 0: - eval_loss = 0 + if global_step % args.eval_freq == 0: trainer.eval() + eval_loss = 0.0 + num_val_steps = 0 with torch.no_grad(): - for val_idx, val_batch in enumerate(val_dl): + for val_batch in val_dl: inputs = tokenizer(val_batch, return_tensors="pt", padding=True) input_ids = inputs["input_ids"][:, :-1].to(device) @@ -119,11 +121,77 @@ def collate_fn(batch): labels=labels, 
vocab_size=trainer.model.vocab_size, ) - loss = orig_loss + exit_loss + total_scale = 1.0 + finetune_arguments.early_exit_loss_scale + loss = ( + 1.0 * orig_loss + + finetune_arguments.early_exit_loss_scale * exit_loss + ) / total_scale eval_loss += loss.item() + num_val_steps += 1 - print( - f"Epoch: {idx}, Train Loss: {loss.item():0.2f} Val Loss: {eval_loss / (val_idx - 1):0.2f}" + logger.info( + f"Epoch {epoch}, Step {global_step}: " + f"Train Loss: {total_loss.item():.4f}, " + f"Val Loss: {eval_loss / num_val_steps:.4f}" ) trainer.train() + + if global_step % args.save_steps == 0: + checkpoint_path = f"{args.output_dir}/checkpoint_{global_step}.pt" + torch.save( + { + "step": global_step, + "epoch": epoch, + "model_state_dict": trainer.model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + }, + checkpoint_path, + ) + logger.info(f"Saved checkpoint to {checkpoint_path}") + + +def main(finetune_arguments): + device = "cuda" if torch.cuda.is_available() else "cpu" + tokenizer = AutoTokenizer.from_pretrained(finetune_arguments.ckpt) + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained(finetune_arguments.ckpt) + + trainer = LayerSkipModel(model=model) + optimizer = torch.optim.Adam(params=model.parameters(), lr=finetune_arguments.lr) + + train_ds = load_dataset(finetune_arguments.ds_ckpt, split="train") + val_ds = load_dataset(finetune_arguments.ds_ckpt, split="eval") + + collate_fn_with_template = partial(collate_fn, template=finetune_arguments.template) + train_dl = DataLoader( + train_ds, + batch_size=finetune_arguments.batch_size, + collate_fn=collate_fn_with_template, + ) + val_dl = DataLoader( + val_ds, + batch_size=finetune_arguments.batch_size, + collate_fn=collate_fn_with_template, + ) + + trainer.to(device) + trainer.train() + + train_and_eval( + train_dl, val_dl, tokenizer, device, trainer, optimizer, finetune_arguments + ) + + +def process_cli_arguments(): + parser = HfArgumentParser((FineTuneArguments)) + finetune_arguments = parser.parse_args_into_dataclasses( + return_remaining_strings=False + ) + return finetune_arguments + + +if __name__ == "__main__": + finetune_arguments = process_cli_arguments() + main(finetune_arguments) From 458ffd65d7dd9d9ccfcd2c1af9c7e5e2f293a210 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Mon, 20 Jan 2025 21:34:54 +0530 Subject: [PATCH 3/8] add eos token --- fine-tune.py | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index a42a428..8898847 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -69,7 +69,9 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args trainer.train() for step, batch in enumerate(train_dl): inputs = tokenizer( - batch, return_tensors="pt", padding=True, truncation=True + batch, + return_tensors="pt", + padding=True, ) input_ids = inputs["input_ids"][:, :-1].to(device) input_attn_mask = inputs["attention_mask"][:, :-1].to(device) @@ -121,10 +123,9 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args labels=labels, vocab_size=trainer.model.vocab_size, ) - total_scale = 1.0 + finetune_arguments.early_exit_loss_scale + total_scale = 1.0 + args.early_exit_loss_scale loss = ( - 1.0 * orig_loss - + finetune_arguments.early_exit_loss_scale * exit_loss + 1.0 * orig_loss + args.early_exit_loss_scale * exit_loss ) / total_scale eval_loss += loss.item() @@ -151,47 +152,46 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, 
trainer, optimizer, args logger.info(f"Saved checkpoint to {checkpoint_path}") -def main(finetune_arguments): +def main(args): device = "cuda" if torch.cuda.is_available() else "cpu" - tokenizer = AutoTokenizer.from_pretrained(finetune_arguments.ckpt) - tokenizer.pad_token = tokenizer.eos_token + tokenizer = AutoTokenizer.from_pretrained(args.ckpt) + tokenizer.add_special_tokens({"pad_token": ""}) # Add pad token - model = AutoModelForCausalLM.from_pretrained(finetune_arguments.ckpt) + tokenizer.add_bos_token = True # This defaults to True + tokenizer.add_eos_token = True # This defaults to False, setting it to True will add eos token to each sample + + model = AutoModelForCausalLM.from_pretrained(args.ckpt) trainer = LayerSkipModel(model=model) - optimizer = torch.optim.Adam(params=model.parameters(), lr=finetune_arguments.lr) + optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) - train_ds = load_dataset(finetune_arguments.ds_ckpt, split="train") - val_ds = load_dataset(finetune_arguments.ds_ckpt, split="eval") + train_ds = load_dataset(args.ds_ckpt, split="train") + val_ds = load_dataset(args.ds_ckpt, split="eval") - collate_fn_with_template = partial(collate_fn, template=finetune_arguments.template) + collate_fn_with_template = partial(collate_fn, template=args.template) train_dl = DataLoader( train_ds, - batch_size=finetune_arguments.batch_size, + batch_size=args.batch_size, collate_fn=collate_fn_with_template, ) val_dl = DataLoader( val_ds, - batch_size=finetune_arguments.batch_size, + batch_size=args.batch_size, collate_fn=collate_fn_with_template, ) trainer.to(device) trainer.train() - train_and_eval( - train_dl, val_dl, tokenizer, device, trainer, optimizer, finetune_arguments - ) + train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args) def process_cli_arguments(): parser = HfArgumentParser((FineTuneArguments)) - finetune_arguments = parser.parse_args_into_dataclasses( - return_remaining_strings=False - ) - return finetune_arguments + args = parser.parse_args_into_dataclasses(return_remaining_strings=False) + return args if __name__ == "__main__": - finetune_arguments = process_cli_arguments() - main(finetune_arguments) + args = process_cli_arguments() + main(args) From 30f04e93f10008e2a3ad30b1d623b56ec13d7307 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Mon, 20 Jan 2025 22:55:31 +0530 Subject: [PATCH 4/8] fix the eos and pad issue --- fine-tune.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index 8898847..241be21 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -155,7 +155,7 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args def main(args): device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained(args.ckpt) - tokenizer.add_special_tokens({"pad_token": ""}) # Add pad token + tokenizer.pad_token = tokenizer.eos_token tokenizer.add_bos_token = True # This defaults to True tokenizer.add_eos_token = True # This defaults to False, setting it to True will add eos token to each sample @@ -187,9 +187,9 @@ def main(args): def process_cli_arguments(): - parser = HfArgumentParser((FineTuneArguments)) + parser = HfArgumentParser(FineTuneArguments) args = parser.parse_args_into_dataclasses(return_remaining_strings=False) - return args + return args[0] if __name__ == "__main__": From 4474cd8d4663902a325c265e504a0c3b2b17986f Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Tue, 21 Jan 2025 01:29:35 +0530 Subject: [PATCH 5/8] bug 
fix 1. added ignore index in loss function 2. inputs and labels are same, as the offset is done inside the loss function of the causal model --- fine-tune.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index 241be21..e521752 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -2,9 +2,10 @@ from dataclasses import dataclass from functools import partial +import os + import torch from torch import nn -from torch.nn import functional from torch.utils.data import DataLoader from datasets import load_dataset @@ -73,19 +74,21 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args return_tensors="pt", padding=True, ) - input_ids = inputs["input_ids"][:, :-1].to(device) - input_attn_mask = inputs["attention_mask"][:, :-1].to(device) + input_ids = inputs["input_ids"].to(device) + input_attn_mask = inputs["attention_mask"].to(device) - labels = inputs["input_ids"][:, 1:].to(device) + labels = inputs["input_ids"].to(device) logits, exit_logits = trainer( input_ids=input_ids, attention_mask=input_attn_mask ) orig_loss = trainer.model.loss_function( - logits=logits, labels=labels, vocab_size=trainer.model.vocab_size + logits=logits, labels=labels, vocab_size=trainer.model.vocab_size, + ignore_index=tokenizer.pad_token_id, ) exit_loss = trainer.model.loss_function( - logits=exit_logits, labels=labels, vocab_size=trainer.model.vocab_size + logits=exit_logits, labels=labels, vocab_size=trainer.model.vocab_size, + ignore_index=tokenizer.pad_token_id, ) total_scale = 1.0 + args.early_exit_loss_scale total_loss = ( @@ -105,10 +108,10 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args for val_batch in val_dl: inputs = tokenizer(val_batch, return_tensors="pt", padding=True) - input_ids = inputs["input_ids"][:, :-1].to(device) - input_attn_mask = inputs["attention_mask"][:, :-1].to(device) + input_ids = inputs["input_ids"].to(device) + input_attn_mask = inputs["attention_mask"].to(device) - labels = inputs["input_ids"][:, 1:].to(device) + labels = inputs["input_ids"].to(device) logits, exit_logits = trainer( input_ids=input_ids, attention_mask=input_attn_mask @@ -117,11 +120,13 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args logits=logits, labels=labels, vocab_size=trainer.model.vocab_size, + ignore_index=tokenizer.pad_token_id, ) exit_loss = trainer.model.loss_function( logits=exit_logits, labels=labels, vocab_size=trainer.model.vocab_size, + ignore_index=tokenizer.pad_token_id, ) total_scale = 1.0 + args.early_exit_loss_scale loss = ( @@ -139,6 +144,9 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args trainer.train() if global_step % args.save_steps == 0: + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + checkpoint_path = f"{args.output_dir}/checkpoint_{global_step}.pt" torch.save( { From 8e9ecc2fad0745904ed715f09abcd53567191af3 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Wed, 22 Jan 2025 00:31:51 -0500 Subject: [PATCH 6/8] Add tqdm progress bar --- fine-tune.py | 123 +++++++++++++++++++++++++-------------------------- 1 file changed, 61 insertions(+), 62 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index e521752..7631af5 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -1,33 +1,32 @@ -import logging from dataclasses import dataclass from functools import partial -import os - import torch from torch import nn +from torch.nn import functional from torch.utils.data 
import DataLoader from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +from tqdm import tqdm +import os @dataclass class FineTuneArguments: - ckpt: str = "meta-llama/Llama-3.2-1B" + ckpt: str = "meta-llama/Llama-2-7b-hf" # <--- This tokenizer has the add_eos and add_bos feature not 3.2 (be careful) ds_ckpt: str = "WillHeld/top_v2" - template: str = "###INST: {utterance}\n\n###RES: {semantic_parse}" - lr: float = 1e-3 + template: str = "### Instruction: {utterance}\n ### Response: {semantic_parse}" + lr: float = 1e-4 batch_size: int = 8 epochs: int = 1 - eval_freq: int = 100 + eval_freq: int = 2000 early_exit_loss_scale: float = 1.0 - save_steps: int = 500 + save_steps: int = 1000 output_dir: str = "./checkpoints" +args = FineTuneArguments() class LayerSkipModel(nn.Module): def __init__(self, model, *args, **kwargs): @@ -59,7 +58,6 @@ def forward(self, input_ids, attention_mask): return logits, exit_logits - def collate_fn(batch, template): return [template.format(**sample) for sample in batch] @@ -68,27 +66,29 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args global_step = 0 for epoch in range(args.epochs): trainer.train() - for step, batch in enumerate(train_dl): + for step, batch in tqdm(enumerate(train_dl), total=len(train_dl)): inputs = tokenizer( batch, return_tensors="pt", padding=True, ) - input_ids = inputs["input_ids"].to(device) - input_attn_mask = inputs["attention_mask"].to(device) - labels = inputs["input_ids"].to(device) + input_ids = inputs["input_ids"] + input_attn_mask = inputs["attention_mask"] + labels = inputs["input_ids"].masked_fill(~input_attn_mask.bool(), -100) + + input_ids = input_ids.to(device) + input_attn_mask = input_attn_mask.to(device) + labels = labels.to(device) logits, exit_logits = trainer( input_ids=input_ids, attention_mask=input_attn_mask ) orig_loss = trainer.model.loss_function( logits=logits, labels=labels, vocab_size=trainer.model.vocab_size, - ignore_index=tokenizer.pad_token_id, ) exit_loss = trainer.model.loss_function( logits=exit_logits, labels=labels, vocab_size=trainer.model.vocab_size, - ignore_index=tokenizer.pad_token_id, ) total_scale = 1.0 + args.early_exit_loss_scale total_loss = ( @@ -105,13 +105,22 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args eval_loss = 0.0 num_val_steps = 0 with torch.no_grad(): - for val_batch in val_dl: + # prompt = "###INST: {utterance}\n\n###RES: ".format(utterance="What time is my bed time alarm set for?") + # tokenizer.add_eos_token = False # For open ended generation + # inputs = tokenizer(prompt, return_tensors="pt").to(device) + # outputs = trainer.model.generate(**inputs, assistant_early_exit=trainer.early_exit_layer, max_new_tokens=40) + # print(tokenizer.batch_decode(outputs)) + # tokenizer.add_eos_token = True # Turn to True for validation and training + for val_batch in tqdm(val_dl, total=len(val_dl)): inputs = tokenizer(val_batch, return_tensors="pt", padding=True) - input_ids = inputs["input_ids"].to(device) - input_attn_mask = inputs["attention_mask"].to(device) + input_ids = inputs["input_ids"] + input_attn_mask = inputs["attention_mask"] + labels = inputs["input_ids"].masked_fill(~input_attn_mask.bool(), -100) - labels = inputs["input_ids"].to(device) + input_ids = input_ids.to(device) + input_attn_mask = input_attn_mask.to(device) + labels = labels.to(device) logits, exit_logits = 
trainer( input_ids=input_ids, attention_mask=input_attn_mask @@ -120,13 +129,11 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args logits=logits, labels=labels, vocab_size=trainer.model.vocab_size, - ignore_index=tokenizer.pad_token_id, ) exit_loss = trainer.model.loss_function( logits=exit_logits, labels=labels, vocab_size=trainer.model.vocab_size, - ignore_index=tokenizer.pad_token_id, ) total_scale = 1.0 + args.early_exit_loss_scale loss = ( @@ -136,7 +143,7 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args eval_loss += loss.item() num_val_steps += 1 - logger.info( + print( f"Epoch {epoch}, Step {global_step}: " f"Train Loss: {total_loss.item():.4f}, " f"Val Loss: {eval_loss / num_val_steps:.4f}" @@ -146,7 +153,7 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args if global_step % args.save_steps == 0: if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - + checkpoint_path = f"{args.output_dir}/checkpoint_{global_step}.pt" torch.save( { @@ -157,49 +164,41 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args }, checkpoint_path, ) - logger.info(f"Saved checkpoint to {checkpoint_path}") - - -def main(args): - device = "cuda" if torch.cuda.is_available() else "cpu" - tokenizer = AutoTokenizer.from_pretrained(args.ckpt) - tokenizer.pad_token = tokenizer.eos_token - - tokenizer.add_bos_token = True # This defaults to True - tokenizer.add_eos_token = True # This defaults to False, setting it to True will add eos token to each sample - - model = AutoModelForCausalLM.from_pretrained(args.ckpt) + print(f"Saved checkpoint to {checkpoint_path}") - trainer = LayerSkipModel(model=model) - optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) +device = "cuda" if torch.cuda.is_available() else "cpu" +tokenizer = AutoTokenizer.from_pretrained(args.ckpt) +tokenizer.pad_token = tokenizer.eos_token - train_ds = load_dataset(args.ds_ckpt, split="train") - val_ds = load_dataset(args.ds_ckpt, split="eval") +# This is only true for llama 2 moedels (check the args for ckpt) +tokenizer.add_bos_token = True # This defaults to True +tokenizer.add_eos_token = True # This defaults to False, setting it to True will add eos token to each sample - collate_fn_with_template = partial(collate_fn, template=args.template) - train_dl = DataLoader( - train_ds, - batch_size=args.batch_size, - collate_fn=collate_fn_with_template, - ) - val_dl = DataLoader( - val_ds, - batch_size=args.batch_size, - collate_fn=collate_fn_with_template, - ) +model = AutoModelForCausalLM.from_pretrained(args.ckpt, torch_dtype="bfloat16", device_map="auto") - trainer.to(device) - trainer.train() +trainer = LayerSkipModel(model=model) +optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) - train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args) +train_ds = load_dataset(args.ds_ckpt, split="train")# load_dataset(args.ds_ckpt, split="train[:1000]") +val_ds = load_dataset(args.ds_ckpt, split="eval") # load_dataset(args.ds_ckpt, split="eval[:50]") +collate_fn_with_template = partial(collate_fn, template=args.template) +train_dl = DataLoader( + train_ds, + batch_size=args.batch_size, + collate_fn=collate_fn_with_template, + shuffle=True, +) +val_dl = DataLoader( + val_ds, + batch_size=args.batch_size, + collate_fn=collate_fn_with_template, +) -def process_cli_arguments(): - parser = HfArgumentParser(FineTuneArguments) - args = 
parser.parse_args_into_dataclasses(return_remaining_strings=False) - return args[0] +trainer.to(device) +trainer.train() +train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args) -if __name__ == "__main__": - args = process_cli_arguments() - main(args) +tokenizer.push_to_hub("ariG23498/layer-skip-vanill-2-7b") +trainer.model.push_to_hub("ariG23498/layer-skip-vanill-2-7b") From 70f6eadbb26252b2e915067b7aa53eb203167ff3 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Thu, 23 Jan 2025 00:40:24 -0500 Subject: [PATCH 7/8] Update fine-tune.py that leads to 2.2x speedup Leads to: ``` ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Orig Time: 0.8662526607513428 From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`. ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.5897870063781738 For Layer: 1 Speedup: 1.4687550783305965 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.35465383529663086 For Layer: 2 Speedup: 2.4425300801464984 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.36078310012817383 For Layer: 3 Speedup: 2.4010344731878877 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.3151836395263672 For Layer: 4 Speedup: 2.748406173788329 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.3399159908294678 For Layer: 5 Speedup: 2.5484316246420207 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.3960268497467041 For Layer: 6 Speedup: 2.1873584109395403 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.38035011291503906 For Layer: 7 Speedup: 2.2775138782326128 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.4112529754638672 For Layer: 8 Speedup: 2.106374208658952 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.3891477584838867 For Layer: 9 Speedup: 2.226025055691568 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.41824865341186523 For Layer: 10 Speedup: 2.0711427369457924 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.43691492080688477 For Layer: 11 Speedup: 1.9826575369675328 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.44853711128234863 For Layer: 12 Speedup: 1.931284254885316 ### 
Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.45146870613098145 For Layer: 13 Speedup: 1.9187435341310743 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.4784407615661621 For Layer: 14 Speedup: 1.8105745378292801 ### Instruction: Set alarm for 6am every day ### Response: [IN:CREATE_ALARM Set alarm [SL:DATE_TIME_RECURRING for 6 am every day ] ] Layerskip Time: 0.4980161190032959 For Layer: 15 Speedup: 1.7394068739883695 ``` --- fine-tune.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/fine-tune.py b/fine-tune.py index 7631af5..f8d90cd 100644 --- a/fine-tune.py +++ b/fine-tune.py @@ -18,13 +18,14 @@ class FineTuneArguments: ckpt: str = "meta-llama/Llama-2-7b-hf" # <--- This tokenizer has the add_eos and add_bos feature not 3.2 (be careful) ds_ckpt: str = "WillHeld/top_v2" template: str = "### Instruction: {utterance}\n ### Response: {semantic_parse}" - lr: float = 1e-4 + lr: float = 2e-5 batch_size: int = 8 epochs: int = 1 - eval_freq: int = 2000 + eval_freq: int = 5000 early_exit_loss_scale: float = 1.0 - save_steps: int = 1000 - output_dir: str = "./checkpoints" + save_steps: int = 5000 + output_dir: str = "./checkpoints/" + hub_id: str = None args = FineTuneArguments() @@ -105,12 +106,6 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args eval_loss = 0.0 num_val_steps = 0 with torch.no_grad(): - # prompt = "###INST: {utterance}\n\n###RES: ".format(utterance="What time is my bed time alarm set for?") - # tokenizer.add_eos_token = False # For open ended generation - # inputs = tokenizer(prompt, return_tensors="pt").to(device) - # outputs = trainer.model.generate(**inputs, assistant_early_exit=trainer.early_exit_layer, max_new_tokens=40) - # print(tokenizer.batch_decode(outputs)) - # tokenizer.add_eos_token = True # Turn to True for validation and training for val_batch in tqdm(val_dl, total=len(val_dl)): inputs = tokenizer(val_batch, return_tensors="pt", padding=True) @@ -151,10 +146,14 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args trainer.train() if global_step % args.save_steps == 0: - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) + step_dir = f"{args.output_dir}/step_{global_step}" + os.makedirs(step_dir, exist_ok=True) - checkpoint_path = f"{args.output_dir}/checkpoint_{global_step}.pt" + model_path = f"{step_dir}/model" + trainer.model.save_pretrained(model_path) + print(f"Saved pretrained model to {model_path}") + + checkpoint_path = f"{step_dir}/checkpoint.pt" torch.save( { "step": global_step, @@ -164,7 +163,7 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args }, checkpoint_path, ) - print(f"Saved checkpoint to {checkpoint_path}") + print(f"Saved training checkpoint to {checkpoint_path}") device = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained(args.ckpt) @@ -179,8 +178,8 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args trainer = LayerSkipModel(model=model) optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) -train_ds = load_dataset(args.ds_ckpt, split="train")# load_dataset(args.ds_ckpt, split="train[:1000]") -val_ds = load_dataset(args.ds_ckpt, split="eval") # load_dataset(args.ds_ckpt, split="eval[:50]") 
+train_ds = load_dataset(args.ds_ckpt, split="train") +val_ds = load_dataset(args.ds_ckpt, split="eval") collate_fn_with_template = partial(collate_fn, template=args.template) train_dl = DataLoader( @@ -200,5 +199,6 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args) -tokenizer.push_to_hub("ariG23498/layer-skip-vanill-2-7b") -trainer.model.push_to_hub("ariG23498/layer-skip-vanill-2-7b") +if args.hub_id: + tokenizer.push_to_hub(args.hub_id) + trainer.model.push_to_hub(args.hub_id) From cd84f88c74db6ce61fc393e33a306d0151826de7 Mon Sep 17 00:00:00 2001 From: ariG23498 Date: Thu, 23 Jan 2025 11:30:39 +0530 Subject: [PATCH 8/8] add args and change readme --- README.md | 26 +++++------ fine-tune.py => train.py | 93 ++++++++++++++++++++++++---------------- 2 files changed, 68 insertions(+), 51 deletions(-) rename fine-tune.py => train.py (74%) diff --git a/README.md b/README.md index 7bb950b..f37f5f4 100644 --- a/README.md +++ b/README.md @@ -45,26 +45,22 @@ In order to access each model: Once you run those steps, the commands below to run the LayerSkip checkpoints should work. -## Fine Tune - +## Train To train any supported HuggingFace model with the LayerSkip approach: ```bash -torchrun finetune_layerskip.py \ - --ckpt facebook/llama2-7B \ - --ds_ckpt some_dataset \ - --template "###INST: {utterance}\n\n###RES: {semantic_parse}" \ - --lr 1e-4 \ +torchrun train.py \ + --ckpt "meta-llama/Llama-2-7b-hf" \ + --ds_ckpt "WillHeld/top_v2" \ + --template "### Instruction: {utterance}\n ### Response: {semantic_parse}" \ + --lr 2e-5 \ --batch_size 8 \ - --epochs 3 \ + --epochs 1 \ + --eval_freq 5000 \ --early_exit_loss_scale 1.0 \ - --eval_freq 50 \ - --output_dir ./checkpoints + --save_steps 5000 \ + --output_dir "./checkpoints/" \ + --hub_id "hf_id/Llama-2-7b-hf-layerskip" \ ``` - -Tips: -- **Adjust your hyperparameters**: Play with `--early_exit_loss_scale` to increase or decrease the importance of the early-exit path. -- **Monitor speed & memory usage**: Even during training, certain intermediate layers might lead to more efficient training if configured properly. -- **Run `python finetune_layerskip.py --help`** to see all command-line options. 
## Generate diff --git a/fine-tune.py b/train.py similarity index 74% rename from fine-tune.py rename to train.py index f8d90cd..ac873c1 100644 --- a/fine-tune.py +++ b/train.py @@ -3,7 +3,6 @@ import torch from torch import nn -from torch.nn import functional from torch.utils.data import DataLoader from datasets import load_dataset @@ -13,9 +12,10 @@ import os + @dataclass class FineTuneArguments: - ckpt: str = "meta-llama/Llama-2-7b-hf" # <--- This tokenizer has the add_eos and add_bos feature not 3.2 (be careful) + ckpt: str = "meta-llama/Llama-2-7b-hf" # Need to keep a check with `add_eos_token` ds_ckpt: str = "WillHeld/top_v2" template: str = "### Instruction: {utterance}\n ### Response: {semantic_parse}" lr: float = 2e-5 @@ -27,7 +27,6 @@ class FineTuneArguments: output_dir: str = "./checkpoints/" hub_id: str = None -args = FineTuneArguments() class LayerSkipModel(nn.Module): def __init__(self, model, *args, **kwargs): @@ -59,6 +58,7 @@ def forward(self, input_ids, attention_mask): return logits, exit_logits + def collate_fn(batch, template): return [template.format(**sample) for sample in batch] @@ -67,7 +67,7 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args global_step = 0 for epoch in range(args.epochs): trainer.train() - for step, batch in tqdm(enumerate(train_dl), total=len(train_dl)): + for _, batch in tqdm(enumerate(train_dl), total=len(train_dl)): inputs = tokenizer( batch, return_tensors="pt", @@ -86,10 +86,14 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args input_ids=input_ids, attention_mask=input_attn_mask ) orig_loss = trainer.model.loss_function( - logits=logits, labels=labels, vocab_size=trainer.model.vocab_size, + logits=logits, + labels=labels, + vocab_size=trainer.model.vocab_size, ) exit_loss = trainer.model.loss_function( - logits=exit_logits, labels=labels, vocab_size=trainer.model.vocab_size, + logits=exit_logits, + labels=labels, + vocab_size=trainer.model.vocab_size, ) total_scale = 1.0 + args.early_exit_loss_scale total_loss = ( @@ -111,7 +115,9 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args input_ids = inputs["input_ids"] input_attn_mask = inputs["attention_mask"] - labels = inputs["input_ids"].masked_fill(~input_attn_mask.bool(), -100) + labels = inputs["input_ids"].masked_fill( + ~input_attn_mask.bool(), -100 + ) input_ids = input_ids.to(device) input_attn_mask = input_attn_mask.to(device) @@ -165,40 +171,55 @@ def train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args ) print(f"Saved training checkpoint to {checkpoint_path}") -device = "cuda" if torch.cuda.is_available() else "cpu" -tokenizer = AutoTokenizer.from_pretrained(args.ckpt) -tokenizer.pad_token = tokenizer.eos_token -# This is only true for llama 2 moedels (check the args for ckpt) -tokenizer.add_bos_token = True # This defaults to True -tokenizer.add_eos_token = True # This defaults to False, setting it to True will add eos token to each sample +def main(args): + device = "cuda" if torch.cuda.is_available() else "cpu" + tokenizer = AutoTokenizer.from_pretrained(args.ckpt) + tokenizer.pad_token = tokenizer.eos_token + + # This is only true for llama 2 moedels (check the args for ckpt) + tokenizer.add_bos_token = True # This defaults to True + tokenizer.add_eos_token = True # This defaults to False, setting it to True will add eos token to each sample + + model = AutoModelForCausalLM.from_pretrained( + args.ckpt, torch_dtype="bfloat16", device_map="auto" + ) + + 
trainer = LayerSkipModel(model=model) + optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) + + train_ds = load_dataset(args.ds_ckpt, split="train") + val_ds = load_dataset(args.ds_ckpt, split="eval") + + collate_fn_with_template = partial(collate_fn, template=args.template) + train_dl = DataLoader( + train_ds, + batch_size=args.batch_size, + collate_fn=collate_fn_with_template, + shuffle=True, + ) + val_dl = DataLoader( + val_ds, + batch_size=args.batch_size, + collate_fn=collate_fn_with_template, + ) -model = AutoModelForCausalLM.from_pretrained(args.ckpt, torch_dtype="bfloat16", device_map="auto") + trainer.to(device) + trainer.train() -trainer = LayerSkipModel(model=model) -optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr) + train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args) -train_ds = load_dataset(args.ds_ckpt, split="train") -val_ds = load_dataset(args.ds_ckpt, split="eval") + if args.hub_id: + tokenizer.push_to_hub(args.hub_id) + trainer.model.push_to_hub(args.hub_id) -collate_fn_with_template = partial(collate_fn, template=args.template) -train_dl = DataLoader( - train_ds, - batch_size=args.batch_size, - collate_fn=collate_fn_with_template, - shuffle=True, -) -val_dl = DataLoader( - val_ds, - batch_size=args.batch_size, - collate_fn=collate_fn_with_template, -) -trainer.to(device) -trainer.train() +def process_cli_arguments(): + parser = HfArgumentParser((FineTuneArguments)) + args = parser.parse_args_into_dataclasses(return_remaining_strings=False) + return args -train_and_eval(train_dl, val_dl, tokenizer, device, trainer, optimizer, args) -if args.hub_id: - tokenizer.push_to_hub(args.hub_id) - trainer.model.push_to_hub(args.hub_id) +if __name__ == "__main__": + args = process_cli_arguments() + main(args)
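
---

Once training finishes, the per-layer speedups quoted in the PATCH 7 commit message can be reproduced with self-speculative decoding via `generate(..., assistant_early_exit=...)`, the same call that appears (commented out) in PATCH 6. The sketch below is a minimal benchmarking harness, not code from the patches: the checkpoint path `./checkpoints/step_5000/model` is a hypothetical example of a directory written by `trainer.model.save_pretrained(...)` (a Hub id pushed via `--hub_id` works the same way), the tokenizer is loaded from the base `meta-llama/Llama-2-7b-hf` checkpoint, and `assistant_early_exit` assumes a recent `transformers` release (the benchmark log itself shows a v4.47 cache notice).

```python
import time

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical paths: substitute any directory saved by
# trainer.model.save_pretrained(...) or the Hub id passed via --hub_id.
ckpt = "./checkpoints/step_5000/model"
base_ckpt = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(base_ckpt)
model = AutoModelForCausalLM.from_pretrained(
    ckpt, torch_dtype="bfloat16", device_map="auto"
)

# Same prompt as the PATCH 7 benchmark; note the template without the response filled in.
prompt = "### Instruction: Set alarm for 6am every day\n ### Response: "
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Baseline: regular autoregressive decoding with the full model.
start = time.time()
outputs = model.generate(**inputs, max_new_tokens=40)
orig_time = time.time() - start
print(tokenizer.batch_decode(outputs)[0])
print(f"Orig Time: {orig_time}")

# Self-speculative decoding: draft tokens by exiting at an early layer,
# then verify them with the full model.
for layer in range(1, model.config.num_hidden_layers):
    start = time.time()
    outputs = model.generate(**inputs, assistant_early_exit=layer, max_new_tokens=40)
    skip_time = time.time() - start
    print(tokenizer.batch_decode(outputs)[0])
    print(f"Layerskip Time: {skip_time} For Layer: {layer} Speedup: {orig_time / skip_time}")
```

Because the full model verifies every drafted token, the generated text matches regular greedy decoding; only the per-layer timings change, mirroring the `Layerskip Time` / `Speedup` lines in the PATCH 7 log.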