Skip to content

Commit 2ffa5bf

Browse files
committed
combing training script
1 parent b4263fe commit 2ffa5bf

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

scripts/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def train_and_eval(
204204
classifier = tails.models.DNN(name=model_name)
205205

206206
tails_loss = TailsLoss(name="loss", w_1=w_1, w_2=w_2)
207-
label_accuracy = LabelAccuracy(threshold=0.5)
207+
label_accuracy = LabelAccuracy(threshold=class_threshold)
208208
# convert position RMSE to pixels
209209
position_rmse = PositionRootMeanSquarredError(scaling_factor=scaling_factor)
210210

@@ -243,7 +243,8 @@ def train_and_eval(
243243
if verbose:
244244
log(stats)
245245

246-
# classifier.model.save_weights('tails')
246+
if save_model:
247+
classifier.model.save_weights(f"{model_name}-{tag}")
247248

248249

249250
if __name__ == "__main__":

0 commit comments

Comments
 (0)