Skip to content

Commit 9742120

Browse files
authored
Merge pull request #119 from amosproj/98-implement-a-script-for-llm-fine-tuning
98 implement a script for llm fine tuning
2 parents ce9477f + ce79baf commit 9742120

File tree

3 files changed

+121
-95
lines changed

3 files changed

+121
-95
lines changed

src/hpc_scripts/CustomSFTTrainer.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import List, Union
2+
from trl import SFTTrainer
3+
import optuna
4+
from transformers.trainer_utils import HPSearchBackend, BestRun, PREFIX_CHECKPOINT_DIR, default_compute_objective
5+
import os
6+
import gc
7+
import torch
8+
9+
10+
class CustomSFTTrainer(SFTTrainer):
    """SFTTrainer whose Optuna search discards the previous model before each trial.

    Every trial drops ``trainer.model``, runs the garbage collector and empties
    the CUDA cache before calling ``trainer.train(trial=trial)``, so the model
    has to be rebuilt for each trial through the trainer's ``model_init``.
    """

    @staticmethod
    def run_hp_search_optuna(trainer, n_trials, direction, **kwargs):
        """Run an Optuna study over ``trainer`` and return its best trial.

        Args:
            trainer: Trainer to tune; each trial calls
                ``trainer.train(trial=trial)``.
            n_trials: Number of trials passed to ``study.optimize``.
            direction: Optimisation direction passed to ``optuna.create_study``.
            **kwargs: ``timeout`` and ``n_jobs`` are forwarded to
                ``study.optimize``; everything else goes to
                ``optuna.create_study``.

        Returns:
            A ``BestRun`` built from the best trial's number, value and params.
        """

        def _objective(trial, checkpoint_dir=None):
            # Pick up a checkpoint sub-directory to resume from, if one exists.
            checkpoint = None
            if checkpoint_dir:
                for subdir in os.listdir(checkpoint_dir):
                    if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                        checkpoint = os.path.join(checkpoint_dir, subdir)

            if not checkpoint:
                # Free GPU memory held by the previous trial's model before the
                # next one is created. The guard keeps a trial from failing
                # with AttributeError when the attribute is already gone.
                # NOTE(review): other objects may still reference the old
                # weights (e.g. the trainer's wrapped model or optimizer) -
                # confirm memory is actually released on the target setup.
                if hasattr(trainer, "model"):
                    del trainer.model
                gc.collect()
                torch.cuda.empty_cache()

            trainer.objective = None
            trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
            # If there hasn't been any evaluation during the training loop.
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)
            return trainer.objective

        timeout = kwargs.pop("timeout", None)
        n_jobs = kwargs.pop("n_jobs", 1)
        study = optuna.create_study(direction=direction, **kwargs)
        study.optimize(_objective, n_trials=n_trials,
                       timeout=timeout, n_jobs=n_jobs)
        best_trial = study.best_trial
        return BestRun(str(best_trial.number), best_trial.value, best_trial.params)

    def hyperparameter_search(
        self,
        hp_space,
        n_trials,
        direction,
        compute_objective=default_compute_objective,
        **kwargs,
    ) -> Union[BestRun, List[BestRun]]:
        """Tune hyperparameters with Optuna and return the best run.

        Args:
            hp_space: Callable stored on the trainer as ``hp_space``.
            n_trials: Number of Optuna trials to run.
            direction: Optimisation direction for the study.
            compute_objective: Callable stored on the trainer and applied to
                the evaluation metrics when training did not set an objective.
            **kwargs: Forwarded to ``run_hp_search_optuna`` (for example
                ``timeout``, ``n_jobs`` or study options).

        Returns:
            The ``BestRun`` produced by ``run_hp_search_optuna``.

        Raises:
            RuntimeError: If the trainer has no ``model_init``; every trial
                deletes the current model, so it could not be rebuilt.
        """
        if getattr(self, "model_init", None) is None:
            raise RuntimeError(
                "To use hyperparameter search, you need to pass your model "
                "through a model_init function."
            )

        self.hp_search_backend = HPSearchBackend.OPTUNA
        self.hp_space = hp_space
        self.hp_name = None
        self.compute_objective = compute_objective
        try:
            best_run = CustomSFTTrainer.run_hp_search_optuna(
                self, n_trials, direction, **kwargs)
        finally:
            # Always leave search mode, even when a trial raises, so a later
            # plain train() call is not treated as part of a search.
            self.hp_search_backend = None
        return best_run
Lines changed: 41 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,124 +1,78 @@
11
# imports
2-
import transformers
2+
import gc
33
from transformers import (AutoModelForCausalLM,
44
AutoTokenizer,
55
TrainingArguments,
6+
BitsAndBytesConfig
67
)
7-
from trl import SFTTrainer
8-
from peft import LoraConfig
8+
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
99
from datasets import load_dataset
10-
from transformers import AutoTokenizer, AutoModelForCausalLM
1110
from huggingface_hub import HfApi, login
12-
from transformers.hyperparameter_search import HPSearchBackend
13-
from transformers.trainer import *
14-
import optuna
15-
import gc
16-
11+
import torch
12+
import CustomSFTTrainer
13+
import random
1714
import os
1815
HF_TOKEN = os.getenv('HF_TOKEN', 'add_hf_token')
1916
api = HfApi()
2017
login(HF_TOKEN, add_to_git_credential=True)
2118

22-
2319
gc.collect()
2420
torch.cuda.empty_cache()
25-
26-
27-
def run_hp_search_optuna(trainer, n_trials, direction, **kwargs):
28-
29-
def _objective(trial, checkpoint_dir=None):
30-
checkpoint = None
31-
if checkpoint_dir:
32-
for subdir in os.listdir(checkpoint_dir):
33-
if subdir.startswith(PREFIX_CHECKPOINT_DIR):
34-
checkpoint = os.path.join(checkpoint_dir, subdir)
35-
#################
36-
# UPDATES START
37-
#################
38-
if not checkpoint:
39-
# free GPU memory
40-
del trainer.model
41-
gc.collect()
42-
torch.cuda.empty_cache()
43-
trainer.objective = None
44-
trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
45-
# If there hasn't been any evaluation during the training loop.
46-
if getattr(trainer, "objective", None) is None:
47-
metrics = trainer.evaluate()
48-
trainer.objective = trainer.compute_objective(metrics)
49-
return trainer.objective
50-
51-
timeout = kwargs.pop("timeout", None)
52-
n_jobs = kwargs.pop("n_jobs", 1)
53-
study = optuna.create_study(direction=direction, **kwargs)
54-
study.optimize(_objective, n_trials=n_trials,
55-
timeout=timeout, n_jobs=n_jobs)
56-
best_trial = study.best_trial
57-
return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
58-
59-
60-
def hyperparameter_search(
61-
self,
62-
hp_space,
63-
n_trials,
64-
direction,
65-
compute_objective=default_compute_objective,
66-
) -> Union[BestRun, List[BestRun]]:
67-
68-
trainer.hp_search_backend = HPSearchBackend.OPTUNA
69-
self.hp_space = hp_space
70-
trainer.hp_name = None
71-
trainer.compute_objective = compute_objective
72-
best_run = run_hp_search_optuna(trainer, n_trials, direction)
73-
self.hp_search_backend = None
74-
return best_run
75-
76-
77-
transformers.trainer.Trainer.hyperparameter_search = hyperparameter_search
78-
79-
8021
# defining hyperparameter search space for optuna
8122

8223

8324
def optuna_hp_space(trial):
    """Return the hyperparameter values sampled for one Optuna trial.

    Args:
        trial: Object exposing Optuna's ``suggest_float`` / ``suggest_int``.

    Returns:
        Dict with ``learning_rate``, ``num_train_epochs`` and ``weight_decay``.
    """
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 3, 15),
        # suggest_loguniform is deprecated in Optuna; suggest_float with
        # log=True is its documented replacement.
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True),
    }
9130

9231
# Define a function to calculate BLEU score
9332

9433

9534
# configuration arguments
96-
model_id = "google/gemma-2-27b-it"
35+
model_id = "google/gemma-2-9b-it"
9736

98-
# model init function for the trainer
37+
# bits and bytes config
38+
bnb_config = BitsAndBytesConfig(
39+
load_in_4bit=True,
40+
bnb_4bit_quant_type="nf4",
41+
bnb_4bit_compute_dtype=torch.bfloat16
42+
)
9943

10044

10145
def model_init(trial):
    """Build a fresh quantized model wrapped with PEFT adapters.

    ``trial`` is accepted so the function can be handed to the trainer as
    ``model_init``; it is not used here. Relies on the module-level
    ``model_id``, ``bnb_config`` and ``lora_config`` (the latter is defined
    outside the visible part of the file).
    """
    quantized = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb_config, device_map="auto")
    kbit_ready = prepare_model_for_kbit_training(quantized)
    return get_peft_model(kbit_ready, lora_config)
10451

10552

10653
# tokenizer load
10754
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='right')
10855

109-
# Loading training and evaluation data
110-
training_dataset = load_dataset(
111-
"Kubermatic/cncf-question-and-answer-dataset-for-llm-training", split="train[:7500]")
112-
eval_dataset = load_dataset(
113-
"Kubermatic/cncf-question-and-answer-dataset-for-llm-training", split="train[7500:8000]")
56+
dataset = load_dataset(
    "Kubermatic/Merged_QAs", split="train")

# Draw a fixed, reproducible sample of 500 distinct row indices.
random.seed(42)
random_indices = random.sample(range(len(dataset)), k=500)

# First 400 sampled rows train the model, the remaining 100 evaluate it.
training_indices = random_indices[:400]
eval_indices = random_indices[400:500]

# The filter predicate runs once per dataset row; testing membership against
# a frozenset is O(1) instead of scanning the index list each time. The
# selected rows are the same as with the lists.
_training_index_set = frozenset(training_indices)
_eval_index_set = frozenset(eval_indices)
training_dataset = dataset.filter(
    lambda _, idx: idx in _training_index_set, with_indices=True)
eval_dataset = dataset.filter(
    lambda _, idx: idx in _eval_index_set, with_indices=True)
11468

11569
max_seq_length = 1024
11670

11771

11872
output_dir = "trained_model"
11973
training_arguments = TrainingArguments(
12074
output_dir=output_dir,
121-
num_train_epochs=1,
75+
num_train_epochs=3,
12276
gradient_checkpointing=True,
12377
per_device_train_batch_size=1,
12478
gradient_accumulation_steps=8,
@@ -163,11 +117,14 @@ def formatting_func(example):
163117
output_texts.append(text)
164118
return output_texts
165119

166-
# instantiation of the trainer
120+
121+
# Passing model
122+
model = model_init(None)
167123

168124

169-
trainer = SFTTrainer(
170-
model=model_id,
125+
# instantiation of the trainer
126+
trainer = CustomSFTTrainer(
127+
model=model,
171128
train_dataset=training_dataset,
172129
eval_dataset=eval_dataset,
173130
args=training_arguments,
@@ -178,10 +135,13 @@ def formatting_func(example):
178135
model_init=model_init,
179136
)
180137

138+
# avoid placing model on device as it is already placed on device in model_init
139+
trainer.place_model_on_device = False
140+
181141
best_trial = trainer.hyperparameter_search(
182142
direction="minimize",
183143
hp_space=optuna_hp_space,
184-
n_trials=20,
144+
n_trials=5,
185145
)
186146

187147
print(best_trial)

src/hpc_scripts/model_training.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,23 @@
1717

1818

1919
# training pipeline taken from https://huggingface.co/blog/gemma-peft
20-
model_id = "google/gemma-2-27b-it"
20+
model_id = "google/gemma-2-9b-it"
2121

2222
bnb_config = BitsAndBytesConfig(
23-
load_in_8bit=True,
24-
bnb_8bit_quant_type="nf4",
25-
bnb_8bit_compute_dtype=torch.bfloat16
23+
load_in_4bit=True,
24+
bnb_4bit_quant_type="nf4",
25+
bnb_4bit_compute_dtype=torch.bfloat16
2626
)
2727

28+
dataset = load_dataset(
    "Kubermatic/Merged_QAs", split="train")
# Dataset.shuffle returns a new dataset instead of shuffling in place, so the
# result has to be assigned for the seeded shuffle to take effect.
dataset = dataset.shuffle(seed=42)
# Seed the split as well; otherwise train_test_split reshuffles with a random
# seed and the train/test rows differ between runs.
dataset = dataset.train_test_split(train_size=0.20, test_size=0.04, seed=42)
32+
2833
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='right')
2934
# TODO: Check if this can be changed to AutoModelForQuestionAnswering with GEMMA
3035
model = AutoModelForCausalLM.from_pretrained(
31-
model_id, quantization_config=bnb_config, device_map="auto")
32-
33-
# Training Data
34-
dataset = load_dataset(
35-
"Kubermatic/cncf-question-and-answer-dataset-for-llm-training", split="train")
36+
model_id, quantization_config=bnb_config, device_map="auto", attn_implementation='eager')
3637

3738

3839
# Training (hyper)parameters (initial config taken from: https://medium.com/@lucamassaron/sherlock-holmes-q-a-enhanced-with-gemma-2b-it-fine-tuning-2907b06d2645)
@@ -44,15 +45,15 @@
4445

4546
training_arguments = TrainingArguments(
4647
output_dir=output_dir,
47-
num_train_epochs=3,
48+
num_train_epochs=5,
4849
gradient_checkpointing=True,
49-
per_device_train_batch_size=16,
50+
per_device_train_batch_size=4,
5051
gradient_accumulation_steps=8,
5152
optim="paged_adamw_32bit",
5253
save_steps=0,
5354
logging_steps=10,
54-
learning_rate=5e-4,
55-
weight_decay=0.001,
55+
learning_rate=1.344609154868106e-05,
56+
weight_decay=0.00019307024914471071,
5657
fp16=True,
5758
bf16=False,
5859
max_grad_norm=0.3,
@@ -63,6 +64,10 @@
6364
report_to="tensorboard",
6465
disable_tqdm=False,
6566
load_best_model_at_end=True,
67+
eval_accumulation_steps=1,
68+
evaluation_strategy='steps',
69+
eval_steps=500,
70+
per_device_eval_batch_size=4
6671
# debug="underflow_overflow"
6772
)
6873

@@ -96,13 +101,14 @@ def formatting_func(example):
96101

97102
trainer = SFTTrainer(
98103
model=model,
99-
train_dataset=dataset,
104+
train_dataset=dataset["train"],
100105
args=training_arguments,
101106
peft_config=lora_config,
102107
formatting_func=formatting_func,
103108
tokenizer=tokenizer,
104109
max_seq_length=max_seq_length,
105110
callbacks=[EarlyStoppingCallback(early_stopping_patience=15)],
111+
eval_dataset=dataset["test"],
106112
)
107113
trainer.train()
108114
print("Model is trained")

0 commit comments

Comments
 (0)