We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b4263fe commit 2ffa5bfCopy full SHA for 2ffa5bf
scripts/train.py
@@ -204,7 +204,7 @@ def train_and_eval(
204
classifier = tails.models.DNN(name=model_name)
205
206
tails_loss = TailsLoss(name="loss", w_1=w_1, w_2=w_2)
207
- label_accuracy = LabelAccuracy(threshold=0.5)
+ label_accuracy = LabelAccuracy(threshold=class_threshold)
208
# convert position RMSE to pixels
209
position_rmse = PositionRootMeanSquarredError(scaling_factor=scaling_factor)
210
@@ -243,7 +243,8 @@ def train_and_eval(
243
if verbose:
244
log(stats)
245
246
- # classifier.model.save_weights('tails')
+ if save_model:
247
+ classifier.model.save_weights(f"{model_name}-{tag}")
248
249
250
if __name__ == "__main__":
0 commit comments