33import  argparse 
44import  concurrent .futures  as  cf 
55import  contextlib 
6+ from  pathlib  import  Path 
67from  random  import  randint 
78from  statistics  import  mean , stdev 
9+ from  uuid  import  uuid4 
810
911import  pycrfsuite 
1012from  sklearn .model_selection  import  train_test_split 
2527def  train_parser_model (
2628    vectors : DataVectors ,
2729    split : float ,
28-     save_model : str ,
30+     save_model : Path ,
2931    seed : int  |  None ,
3032    html : bool ,
3133    detailed_results : bool ,
3234    plot_confusion_matrix : bool ,
35+     keep_model : bool  =  True ,
3336) ->  Stats :
3437    """Train model using vectors, splitting the vectors into a train and evaluation 
3538    set based on <split>. The trained model is saved to <save_model>. 
@@ -40,7 +43,7 @@ def train_parser_model(
4043        Vectors loaded from training csv files 
4144    split : float 
4245        Fraction of vectors to use for evaluation. 
43-     save_model : str  
46+     save_model : Path  
4447        Path to save trained model to. 
4548    seed : int | None 
4649        Integer used as seed for splitting the vectors between the training and 
@@ -53,6 +56,9 @@ def train_parser_model(
5356        the test set. 
5457    plot_confusion_matrix : bool 
5558        If True, plot a confusion matrix of the token labels. 
59+     keep_model : bool, optional 
60+         If False, delete model from disk after evaluating its performance. 
61+         Default is True. 
5662
5763    Returns 
5864    ------- 
@@ -109,11 +115,11 @@ def train_parser_model(
109115    )
110116    for  X , y  in  zip (features_train , truth_train ):
111117        trainer .append (X , y )
112-     trainer .train (save_model )
118+     trainer .train (str ( save_model ) )
113119
114120    print ("[INFO] Evaluating model with test data." )
115121    tagger  =  pycrfsuite .Tagger ()  # type: ignore 
116-     tagger .open (save_model )
122+     tagger .open (str ( save_model ) )
117123
118124    labels_pred , scores_pred  =  [], []
119125    for  X  in  features_test :
@@ -146,6 +152,10 @@ def train_parser_model(
146152        confusion_matrix (labels_pred , truth_test )
147153
148154    stats  =  evaluate (labels_pred , truth_test , seed )
155+ 
156+     if  not  keep_model :
157+         save_model .unlink (missing_ok = True )
158+ 
149159    return  stats 
150160
151161
@@ -167,11 +177,12 @@ def train_single(args: argparse.Namespace) -> None:
167177    stats  =  train_parser_model (
168178        vectors ,
169179        args .split ,
170-         save_model ,
180+         Path ( save_model ) ,
171181        args .seed ,
172182        args .html ,
173183        args .detailed ,
174184        args .confusion ,
185+         keep_model = True ,
175186    )
176187
177188    print ("Sentence-level results:" )
@@ -201,19 +212,22 @@ def train_multiple(args: argparse.Namespace) -> None:
201212    else :
202213        save_model  =  args .save_model 
203214
204-     # The first None argument is for the seed. This is set to None so each 
205-     # iteration of the training function uses a different random seed. 
206-     arguments  =  [
207-         (
208-             vectors ,
209-             args .split ,
210-             save_model ,
211-             None ,
212-             args .html ,
213-             args .detailed ,
214-             args .confusion ,
215+     arguments  =  []
216+     for  _  in  range (args .runs ):
217+         # The first None argument is for the seed. This is set to None so each 
218+         # iteration of the training function uses a different random seed. 
219+         arguments .append (
220+             (
221+                 vectors ,
222+                 args .split ,
223+                 Path (save_model ).with_stem ("model-"  +  str (uuid4 ())),
224+                 None ,  # Seed 
225+                 args .html ,
226+                 args .detailed ,
227+                 args .confusion ,
228+                 False ,  # keep_model 
229+             )
215230        )
216-     ] *  args .runs 
217231
218232    eval_results  =  []
219233    with  contextlib .redirect_stdout (None ):  # Suppress print output 
0 commit comments