Skip to content

Commit 09b0c28

Browse files
committed
Feature: When training multiple models, give each one a unique file and delete after evaluating so that the existing model is not modified.
1 parent 53678aa commit 09b0c28

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

train/train_model.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import argparse
44
import concurrent.futures as cf
55
import contextlib
6+
from pathlib import Path
67
from random import randint
78
from statistics import mean, stdev
9+
from uuid import uuid4
810

911
import pycrfsuite
1012
from sklearn.model_selection import train_test_split
@@ -25,11 +27,12 @@
2527
def train_parser_model(
2628
vectors: DataVectors,
2729
split: float,
28-
save_model: str,
30+
save_model: Path,
2931
seed: int | None,
3032
html: bool,
3133
detailed_results: bool,
3234
plot_confusion_matrix: bool,
35+
keep_model: bool = True,
3336
) -> Stats:
3437
"""Train model using vectors, splitting the vectors into a train and evaluation
3538
set based on <split>. The trained model is saved to <save_model>.
@@ -40,7 +43,7 @@ def train_parser_model(
4043
Vectors loaded from training csv files
4144
split : float
4245
Fraction of vectors to use for evaluation.
43-
save_model : str
46+
save_model : Path
4447
Path to save trained model to.
4548
seed : int | None
4649
Integer used as seed for splitting the vectors between the training and
@@ -53,6 +56,9 @@ def train_parser_model(
5356
the test set.
5457
plot_confusion_matrix : bool
5558
If True, plot a confusion matrix of the token labels.
59+
keep_model : bool, optional
60+
If False, delete model from disk after evaluating its performance.
61+
Default is True.
5662
5763
Returns
5864
-------
@@ -109,11 +115,11 @@ def train_parser_model(
109115
)
110116
for X, y in zip(features_train, truth_train):
111117
trainer.append(X, y)
112-
trainer.train(save_model)
118+
trainer.train(str(save_model))
113119

114120
print("[INFO] Evaluating model with test data.")
115121
tagger = pycrfsuite.Tagger() # type: ignore
116-
tagger.open(save_model)
122+
tagger.open(str(save_model))
117123

118124
labels_pred, scores_pred = [], []
119125
for X in features_test:
@@ -146,6 +152,10 @@ def train_parser_model(
146152
confusion_matrix(labels_pred, truth_test)
147153

148154
stats = evaluate(labels_pred, truth_test, seed)
155+
156+
if not keep_model:
157+
save_model.unlink(missing_ok=True)
158+
149159
return stats
150160

151161

@@ -167,11 +177,12 @@ def train_single(args: argparse.Namespace) -> None:
167177
stats = train_parser_model(
168178
vectors,
169179
args.split,
170-
save_model,
180+
Path(save_model),
171181
args.seed,
172182
args.html,
173183
args.detailed,
174184
args.confusion,
185+
keep_model=True,
175186
)
176187

177188
print("Sentence-level results:")
@@ -201,19 +212,22 @@ def train_multiple(args: argparse.Namespace) -> None:
201212
else:
202213
save_model = args.save_model
203214

204-
# The first None argument is for the seed. This is set to None so each
205-
# iteration of the training function uses a different random seed.
206-
arguments = [
207-
(
208-
vectors,
209-
args.split,
210-
save_model,
211-
None,
212-
args.html,
213-
args.detailed,
214-
args.confusion,
215+
arguments = []
216+
for _ in range(args.runs):
217+
# The first None argument is for the seed. This is set to None so each
218+
# iteration of the training function uses a different random seed.
219+
arguments.append(
220+
(
221+
vectors,
222+
args.split,
223+
Path(save_model).with_stem("model-" + str(uuid4())),
224+
None, # Seed
225+
args.html,
226+
args.detailed,
227+
args.confusion,
228+
False, # keep_model
229+
)
215230
)
216-
] * args.runs
217231

218232
eval_results = []
219233
with contextlib.redirect_stdout(None): # Suppress print output

0 commit comments

Comments
 (0)