This repository was archived by the owner on May 1, 2025. It is now read-only.

Commit 6f89b00 (merge of 2 parents: 587d6fb + 4d706f6)
Commit message: update

24 files changed: 195 additions, 61 deletions

README.md (1 addition, 1 deletion)

@@ -87,7 +87,7 @@ conda activate codetf
 
 2. Install from [PyPI](https://pypi.org/project/salesforce-codetf/):
 ```bash
-pip install salesforce-codetf==1.0.2.1
+pip install salesforce-codetf
 ```
 
 3. Alternatively, build CodeTF from source:

codetf/common/__init__.py

Whitespace-only changes.

codetf/configs/inference/causal_lm.yaml (2 additions, 2 deletions)

@@ -7,8 +7,8 @@ causallm-codegen-350M-multi-pretrained:
   tokenizer_url: "Salesforce/codegen-350M-multi"
   max_prediction_length: 512
 causallm-codegen-350M-nl-pretrained:
-  huggingface_url: "Salesforce/codegen-350-nl"
-  tokenizer_url: "Salesforce/codegen-350-nl"
+  huggingface_url: "Salesforce/codegen-350M-nl"
+  tokenizer_url: "Salesforce/codegen-350M-nl"
   max_prediction_length: 512
 causallm-codegen-2B-mono-pretrained:
   huggingface_url: "Salesforce/codegen-2B-mono"

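The fix above repoints the codegen-350M-nl card at a Hugging Face repository id that actually exists (the "M" was missing). A minimal sketch of loading through this card, following the pattern of test_evaluation/test_evaluate_human_eval_codegen.py in this same commit; the exact download behavior depends on your hardware and network:

```python
# Sketch: loading the corrected model card through the CodeTF pipeline.
# Argument names mirror the test script in this commit; the card name
# comes from the YAML entry above.
from codetf.models import load_model_pipeline

model_class = load_model_pipeline(model_name="causallm", task="pretrained",
                                  model_type="codegen-350M-nl", is_eval=True,
                                  load_in_8bit=False, weight_sharding=False)
print(model_class.get_tokenizer())
```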
codetf/configs/inference/codet5.yaml (3 additions, 3 deletions)

@@ -140,20 +140,20 @@ codet5-plus-770M-python-pretrained:
   max_prediction_length: 512
   beam_size: 5
   trust_remote_code: False
-  device_map: True
+  device_map: False
 codet5-plus-770M-pretrained:
   huggingface_url: "Salesforce/codet5p-770m"
   tokenizer_url: "Salesforce/codet5p-770m"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
   trust_remote_code: False
-  device_map: True
+  device_map: False
 codet5-plus-220M-pretrained:
   huggingface_url: "Salesforce/codet5p-220m"
   tokenizer_url: "Salesforce/codet5p-220m"
   max_source_length: 512
   max_prediction_length: 512
   beam_size: 5
   trust_remote_code: False
-  device_map: True
+  device_map: False

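Flipping device_map to False routes these CodeT5+ cards through the single-device branch of the seq2seq loader (changed later in this commit), which moves the model with an explicit .to(device). A bare True is not one of the values transformers documents for device_map ("auto", another placement strategy, or an explicit dict), so forwarding it verbatim was fragile. A sketch of the two paths, with the checkpoint name taken from the config above:

```python
# Sketch of the branch the seq2seq loader takes for each device_map value.
import torch
from transformers import AutoModelForSeq2SeqLM

checkpoint = "Salesforce/codet5p-770m"
device_map = False  # the new default in codet5.yaml

if device_map:
    # device_map should be "auto" (or another strategy string / explicit
    # dict), not a bare boolean.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint, low_cpu_mem_usage=True, device_map="auto")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint, low_cpu_mem_usage=True).to(device)
```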
codetf/configs/training/causal_lm.yaml (1 addition, 3 deletions)

@@ -6,9 +6,6 @@ hyperparameters:
   num_train_epochs: 10
   auto_find_batch_size: True
   batch_size: 4
-  max_steps: 1000
-  eval_steps: 100
-  save_steps: 1000
   logging_steps: 100
   per_device_train_batch_size: 8
   per_device_eval_batch_size: 8
@@ -30,3 +27,4 @@ hyperparameters:
   beam_size: 5
   max_grad_norm: 5.0
   adam_epsilon : 1e-06
+  load_best_model_at_end: True

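Removing max_steps means training is now bounded by num_train_epochs alone, and dropping eval_steps/save_steps matters because transformers' Trainer requires the save strategy to match the evaluation strategy when load_best_model_at_end=True. A hedged sketch of how these keys map onto TrainingArguments; the strategy values are illustrative, since they come from parts of the config not shown in this diff:

```python
# Illustrative mapping of the YAML keys above onto TrainingArguments
# (transformers 4.30). The strategy values are assumptions for the sketch.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=10,
    auto_find_batch_size=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_steps=100,
    evaluation_strategy="epoch",  # must equal save_strategy ...
    save_strategy="epoch",        # ... when load_best_model_at_end=True
    load_best_model_at_end=True,
)
```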
codetf/configs/training/codet5.yaml (1 addition, 3 deletions)

@@ -6,9 +6,6 @@ hyperparameters:
   num_train_epochs: 1
   auto_find_batch_size: True
   batch_size: 4
-  max_steps: 1000
-  eval_steps: 100
-  save_steps: 1000
   logging_steps: 100
   per_device_train_batch_size: 8
   per_device_eval_batch_size: 8
@@ -30,6 +27,7 @@ hyperparameters:
   beam_size: 5
   max_grad_norm: 5.0
   adam_epsilon : 1e-06
+  load_best_model_at_end: True
 lora:
   r: 8
   lora_alpha: 32

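The lora: block survives unchanged; for reference, here is roughly how such values become a peft LoraConfig. Only r and lora_alpha come from the YAML above; the task type and dropout are assumptions for the sketch:

```python
# Hedged sketch: turning the lora: YAML block into a peft config.
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # CodeT5 is an encoder-decoder model
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,  # not in the YAML; illustrative default
)
```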
codetf/data_utility/codexglue_dataset.py (1 addition, 0 deletions)

@@ -25,6 +25,7 @@ def load_codexglue_text_to_code_dataset(self, *args, **kwargs):
         dataset = load_dataset(dataset)
 
         train = dataset["train"]
+        train = train[:50]
         train_nl_tensors, _ = self.process_data(train["nl"])
         train_code_tensors, _ = self.process_data(train["code"])
 

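The new train = train[:50] line caps the split at its first 50 examples, presumably to keep runs fast. Slicing a datasets.Dataset returns a plain dict of column lists, so the train["nl"] and train["code"] lookups below keep working. A self-contained illustration:

```python
# Slicing a Hugging Face Dataset yields a dict of columns, not a Dataset.
from datasets import Dataset

train = Dataset.from_dict({
    "nl": ["add two numbers"],
    "code": ["def add(a, b):\n    return a + b"],
})
subset = train[:50]  # {'nl': [...], 'code': [...]}
print(type(subset).__name__, subset["nl"])
```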
codetf/data_utility/human_eval_dataset.py (10 additions, 3 deletions)

@@ -9,6 +9,12 @@ class HumanEvalDataset(BaseDataset):
     def __init__(self, tokenizer, max_length=512):
         super().__init__(tokenizer, max_length)
 
+    def get_reference(self, task):
+        """Builds the reference solution for the doc (sample from the test dataset)."""
+        test_func = task["test"]
+        entry_point = f"check({task['entry_point']})"
+        return "\n" + test_func + "\n" + entry_point
+
     def load(self):
         dataset = self.dataset_config["openai_humaneval"]
 
@@ -22,9 +28,10 @@ def load(self):
             # without strip, the model generates commented codes ...
             prompts.append(self.tokenizer.eos_token + dataset[task_index]["prompt"].strip())
 
-            unit_test = dataset[task_index]["test"]
-            unit_test = re.sub(r'METADATA = {[^}]*}', '', unit_test, flags=re.MULTILINE)
-            references.append(unit_test)
+            # unit_test = dataset[task_index]["test"]
+            # unit_test = re.sub(r'METADATA = {[^}]*}', '', unit_test, flags=re.MULTILINE)
+            reference = self.get_reference(dataset[task_index])
+            references.append(reference)
 
         prompt_token_ids, prompt_attention_masks = self.process_data(prompts, padding="max_length")

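The new get_reference builds a directly executable check for each problem instead of shipping the raw, METADATA-laden test blob: the task's test code followed by a call to its check function. A small illustration with a shortened stand-in for a real openai_humaneval record (the field names match the dataset; the test body is abbreviated):

```python
# Stand-in HumanEval task; "test" defines check(candidate) and
# "entry_point" names the function a solution must provide.
task = {
    "entry_point": "add",
    "test": "def check(candidate):\n    assert candidate(1, 2) == 3",
}

reference = "\n" + task["test"] + "\n" + f"check({task['entry_point']})"
print(reference)
# Appended after a candidate solution, this string defines check() and
# immediately runs it, which is the shape the code_eval metric expects.
```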
codetf/data_utility/stackexchange_instruction_dataset.py

Whitespace-only changes.

codetf/models/__init__.py (7 additions, 0 deletions)

@@ -51,6 +51,13 @@ def load_model_pipeline(model_name, model_type="base", task="sum",
 
     return model
 
+def load_model_from_path(checkpoint_path, tokenizer_path, model_name, is_eval=True, load_in_8bit=False, load_in_4bit=False):
+    model_cls = registry.get_model_class(model_name)
+    model = model_cls.from_custom(checkpoint_path=checkpoint_path, tokenizer_path=tokenizer_path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit)
+    if is_eval:
+        model.eval()
+
+    return model
 
 class ModelZoo:
     def __init__(self, config_files):

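load_model_from_path is the new entry point for locally saved checkpoints, delegating to the from_custom classmethod added to BaseModel below. A hedged usage sketch; the paths are placeholders, and "causallm" is the registry key used elsewhere in this commit's test script:

```python
# Usage sketch for the new helper (paths are placeholders).
from codetf.models import load_model_from_path

model = load_model_from_path(
    checkpoint_path="./checkpoints/my-finetuned-codegen",  # local save dir
    tokenizer_path="Salesforce/codegen-350M-multi",        # paired tokenizer
    model_name="causallm",
    is_eval=True,
    load_in_8bit=False,
)
```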
codetf/models/base_model.py (11 additions, 1 deletion)

@@ -45,7 +45,17 @@ def from_pretrained(model_class, model_card, load_in_8bit=False, load_in_4bit=Fa
         Build a pretrained model from default configuration file, specified by model_type.
         """
         model_config = OmegaConf.load(get_abs_path(model_class.MODEL_DICT))[model_card]
-        model_cls = model_class.load_model_from_config(model_config=model_config, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit, weight_sharding=weight_sharding)
+        model_cls = model_class.load_huggingface_model_from_config(model_config=model_config, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit, weight_sharding=weight_sharding)
+
+        return model_cls
+
+
+    @classmethod
+    def from_custom(model_class, checkpoint_path, tokenizer_path, load_in_8bit=False, load_in_4bit=False):
+        """
+        Build a model from a custom checkpoint path and tokenizer path.
+        """
+        model_cls = model_class.load_custom_model(checkpoint_path, tokenizer_path, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit)
 
         return model_cls

codetf/models/causal_lm_models/__init__.py (30 additions, 1 deletion)

@@ -29,7 +29,7 @@ def init_tokenizer(cls, model):
         return tokenizer
 
     @classmethod
-    def load_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
+    def load_huggingface_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
         checkpoint = model_config["huggingface_url"]
 
         if load_in_8bit and load_in_4bit:
@@ -79,6 +79,35 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
             model_config=model_config,
             tokenizer=tokenizer
         )
+
+    @classmethod
+    def load_custom_model(model_class, checkpoint_path, tokenizer_path, load_in_8bit=False, load_in_4bit=False):
+
+        if load_in_8bit and load_in_4bit:
+            raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
+
+        if load_in_8bit:
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path,
+                                                load_in_8bit=load_in_8bit,
+                                                low_cpu_mem_usage=True,
+                                                device_map="auto")
+        elif load_in_4bit:
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path,
+                                                load_in_4bit=load_in_4bit,
+                                                low_cpu_mem_usage=True,
+                                                device_map="auto")
+        else:
+            model = AutoModelForCausalLM.from_pretrained(checkpoint_path,
+                                                low_cpu_mem_usage=True,
+                                                device_map="auto")
+
+        tokenizer = model_class.init_tokenizer(tokenizer_path)
+
+        return model_class(
+            model=model,
+            model_config=None,  # custom checkpoints carry no YAML model config
+            tokenizer=tokenizer
+        )
 
     def forward(self, sources, max_length=512):
         encoding = self.tokenizer(sources, return_tensors='pt').to(self.device)

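The quantized branches mirror what bitsandbytes needs: exactly one of load_in_8bit/load_in_4bit, plus a device_map so the quantized weights can be placed on GPU. A minimal standalone version of the 8-bit path; the checkpoint is a placeholder, and running it requires a CUDA device with bitsandbytes installed:

```python
# Standalone sketch of the 8-bit loading path above.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Salesforce/codegen-350M-multi",  # placeholder checkpoint
    load_in_8bit=True,
    low_cpu_mem_usage=True,
    device_map="auto",  # needed so bitsandbytes can place quantized weights
)
```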
codetf/models/seq2seq_models/__init__.py (31 additions, 4 deletions)

@@ -30,15 +30,15 @@ def init_tokenizer(cls, model):
 
 
     @classmethod
-    def load_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
+    def load_huggingface_model_from_config(model_class, model_config, load_in_8bit=False, load_in_4bit=False, weight_sharding=False):
 
         checkpoint = model_config["huggingface_url"]
 
         if load_in_8bit and load_in_4bit:
             raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
 
         # This "device" is for the case of CodeT5plus, will be removed in the future
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if weight_sharding:
             try:
                 # Try to download and load the json index file
@@ -85,12 +85,10 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
         else:
             if model_config["device_map"]:
                 model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                                load_in_4bit=load_in_4bit,
                                                 low_cpu_mem_usage=True,
                                                 device_map=model_config["device_map"], trust_remote_code=model_config["trust_remote_code"])
             else:
                 model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
-                                                load_in_4bit=load_in_4bit,
                                                 low_cpu_mem_usage=True,
                                                 trust_remote_code=model_config["trust_remote_code"]).to(device)
 
@@ -103,6 +101,35 @@ def load_model_from_config(model_class, model_config, load_in_8bit=False, load_i
             tokenizer=tokenizer
         )
 
+    @classmethod
+    def load_custom_model(model_class, checkpoint_path, tokenizer_path, load_in_8bit=False, load_in_4bit=False):
+
+        if load_in_8bit and load_in_4bit:
+            raise ValueError("Only one of load_in_8bit or load_in_4bit can be True. Please choose one.")
+
+        if load_in_8bit:
+            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path,
+                                                load_in_8bit=load_in_8bit,
+                                                low_cpu_mem_usage=True,
+                                                device_map="auto")
+        elif load_in_4bit:
+            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path,
+                                                load_in_4bit=load_in_4bit,
+                                                low_cpu_mem_usage=True,
+                                                device_map="auto")
+        else:
+            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path,
+                                                low_cpu_mem_usage=True,
+                                                device_map="auto")
+
+        tokenizer = model_class.init_tokenizer(tokenizer_path)
+
+        return model_class(
+            model=model,
+            model_config=None,  # custom checkpoints carry no YAML model config
+            tokenizer=tokenizer
+        )
 
     def forward(self, sources, max_length=512, beam_size=5):
         encoding = self.tokenizer(sources, return_tensors='pt').to(self.model.device)

codetf/performance/model_evaluator.py (9 additions, 7 deletions)

@@ -23,16 +23,17 @@ def __init__(self, model_class, num_workers=5):
 
 
     def evaluate_pass_k(self, problems, unit_tests, batch_size=1, max_length=600,
-                        top_p=0.95, k=[1,10,100],
+                        top_p=0.95, k=[1,10,100], temperature=1.2,
                         num_return_sequences=200, sequences_per_chunk=10, num_workers=1):
         # Load dataset
         data_loader = DataLoader(problems, batch_size=batch_size)
         data_loader = self.accelerator.prepare(data_loader)
-
+        model_name = type(self.model_class).__name__
         # Initialize stopping criteria
         gen_kwargs = {
             "do_sample": True,
             "top_p": top_p,
+            "temperature": temperature,
             "stopping_criteria": StoppingCriteriaList([EndOfFunctionCriteria(0, EOF_STRINGS, self.model_class.get_tokenizer())]),
         }
 
@@ -54,7 +55,6 @@ def evaluate_pass_k(self, problems, unit_tests, batch_size=1, max_length=600,
                 input_ids = prompt_ids[0, :attention_masks[0].sum().item()]
 
                 input_data = self.model_class.get_tokenizer().decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-
                 batch_generated_ids = self.model_class.get_model().generate(
                     input_ids=input_ids.unsqueeze(0),
                     attention_mask=attention_masks[0, :attention_masks[0].sum().item()].unsqueeze(0),
@@ -66,14 +66,16 @@ def evaluate_pass_k(self, problems, unit_tests, batch_size=1, max_length=600,
                 gen_codes = self.model_class.get_tokenizer().batch_decode(batch_generated_ids,
                                                 skip_special_tokens=True, clean_up_tokenization_spaces=True)
 
-                for item in gen_codes:
-                    cleaned = remove_last_block(item)
-                    solutions_per_chunk.append(cleaned)
+                for i, item in enumerate(gen_codes):
+                    result = remove_last_block(item)
+                    if model_name == "Seq2SeqModel":
+                        result = f"{input_data} {result}"
+
+                    solutions_per_chunk.append(result)
 
             solutions.append(solutions_per_chunk)
             dataloader_pbar.set_description(f"Processing step {step+1}/{len(data_loader)}")
 
-
         pass_at_k, _ = self.code_eval.compute(
             references=unit_tests, predictions=solutions, k=k, num_workers=num_workers
         )

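self.code_eval here is the code_eval metric from the evaluate library, whose pin (evaluate==0.4.0) is added to requirements.txt later in this commit. Because the metric executes generated code against the reference checks, it must be opted into explicitly. A minimal end-to-end example:

```python
# Minimal pass@k computation with the code_eval metric.
import os
os.environ["HF_ALLOW_CODE_EVAL"] = "1"  # code_eval runs untrusted code; opt in

import evaluate

code_eval = evaluate.load("code_eval")
pass_at_k, _ = code_eval.compute(
    references=["assert add(1, 2) == 3"],
    predictions=[["def add(a, b):\n    return a + b"]],
    k=[1],
)
print(pass_at_k)  # {'pass@1': 1.0}
```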
codetf/trainer/base_trainer.py (5 additions, 0 deletions)

@@ -52,6 +52,7 @@ def init_trainer(self):
 
     def train(self):
        self.trainer.train()
+       # self.trainer.save_model(self.checkpoints_path)
 
     def evaluate(self, dataset=None):
        self.trainer.evaluate(dataset)
@@ -70,8 +71,10 @@ def get_default_codet5_hyperparameters(self):
            sharded_ddp=hyperparameters_config["sharded_ddp"],
            logging_steps=hyperparameters_config["logging_steps"],
            evaluation_strategy=hyperparameters_config["evaluation_strategy"],
+           save_strategy=hyperparameters_config["save_strategy"],
            gradient_checkpointing=hyperparameters_config["gradient_checkpointing"],
            auto_find_batch_size=hyperparameters_config["auto_find_batch_size"],
+           load_best_model_at_end=hyperparameters_config["load_best_model_at_end"],
            output_dir=self.checkpoints_path
        )
        # return hyperparameters_config
@@ -91,8 +94,10 @@ def get_default_causal_lm_hyperparameters(self):
            sharded_ddp=hyperparameters_config["sharded_ddp"],
            logging_steps=hyperparameters_config["logging_steps"],
            evaluation_strategy=hyperparameters_config["evaluation_strategy"],
+           save_strategy=hyperparameters_config["save_strategy"],
            gradient_checkpointing=hyperparameters_config["gradient_checkpointing"],
            auto_find_batch_size=hyperparameters_config["auto_find_batch_size"],
+           load_best_model_at_end=hyperparameters_config["load_best_model_at_end"],
            output_dir=self.checkpoints_path
        )
        # return hyperparameters_config

codetf/trainer/causal_lm_trainer.py (1 addition, 0 deletions)

@@ -21,6 +21,7 @@ def __init__(self, train_dataset, validation_dataset=None, tokenizer=None,
         self.trainer = self.init_trainer()
 
         if peft:
+            self.peft = peft
             self.model = prepare_model_for_int8_training(self.model)
             if peft == "lora":
                 peft_config = self.get_default_lora_config_for_codet5()

codetf/trainer/codet5_trainer.py (1 addition, 0 deletions)

@@ -38,6 +38,7 @@ def __init__(self, train_dataset, validation_dataset=None, tokenizer=None,
         self.trainer = self.init_trainer()
 
         if peft:
+            self.peft = peft
             self.model = prepare_model_for_int8_training(self.model)
             if peft == "lora":
                 peft_config = self.get_default_lora_config_for_codet5()

requirements.txt (2 additions, 1 deletion)

@@ -20,4 +20,5 @@ torchvision==0.15.2
 tqdm==4.63.0
 transformers==4.30.2
 tree-sitter==0.20.1
-bitsandbytes==0.39.1
+bitsandbytes==0.39.1
+evaluate==0.4.0

(The bitsandbytes pin itself is unchanged; it appears on both sides of the diff only because the line gained a trailing newline when evaluate==0.4.0 was appended.)

setup.py (3 additions, 2 deletions)

@@ -22,7 +22,8 @@
     "tqdm==4.63.0",
     "transformers==4.30.2",
     "tree-sitter==0.20.1",
-    "bitsandbytes==0.39.1"
+    "bitsandbytes==0.39.1",
+    "evaluate==0.4.0"
 ]
 
 DEPENDENCY_LINKS = []
@@ -33,7 +34,7 @@
 
 setup(
     name = 'salesforce-codetf',
-    version = "1.0.2",
+    version = "1.0.2.2",
     py_modules = ['codetf'],
     description = 'CodeTF: A Transformer-based Library for Code Intelligence',
     author = 'Nghi D. Q. Bui',

test_evaluation/test_evaluate_human_eval_codegen.py (2 additions, 2 deletions)

@@ -11,7 +11,7 @@
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 model_class = load_model_pipeline(model_name="causallm", task="pretrained",
-                                  model_type="codegen-350M-mono", is_eval=True,
+                                  model_type="codegen-350M-multi", is_eval=True,
                                   load_in_8bit=True, weight_sharding=False)
 
 dataset = HumanEvalDataset(tokenizer=model_class.get_tokenizer())
@@ -20,6 +20,6 @@
 problems = TensorDataset(prompt_token_ids, prompt_attention_masks)
 
 evaluator = ModelEvaluator(model_class)
-avg_pass_at_k = evaluator.evaluate_pass_k(problems=problems, unit_tests=references)
+avg_pass_at_k = evaluator.evaluate_pass_k(problems=problems, unit_tests=references, sequences_per_chunk=200, num_workers=5)
 print("Pass@k: ", avg_pass_at_k)

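With sequences_per_chunk=200 equal to the default num_return_sequences=200, each problem is now generated in a single chunk, and code_eval scores the 200 samples per problem with the unbiased pass@k estimator from the Codex paper. A sketch of that estimator, where n is the number of samples and c the number that pass:

```python
# Unbiased pass@k estimator used by code_eval: 1 - C(n-c, k) / C(n, k),
# computed as a numerically stable product.
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    if n - c < k:
        return 1.0
    return 1.0 - float(np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))

print(pass_at_k(n=200, c=10, k=1))  # 0.05, i.e. c/n when k = 1
```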