Skip to content

Commit 1e3e506

Browse files
committed
init work on training pipeline
1 parent 5c31154 commit 1e3e506

File tree

1 file changed

+255
-0
lines changed

1 file changed

+255
-0
lines changed

scripts/train.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
#!/usr/bin/env python
2+
import fire
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
# import pandas as pd
7+
# import pathlib
8+
# from sklearn.model_selection import train_test_split
9+
# import sys
10+
import tensorflow as tf
11+
12+
# from tqdm.auto import tqdm
13+
14+
import tails.models
15+
from tails.utils import log
16+
17+
18+
def fnr_vs_fpr(predictions, ground_truth):
19+
rbbins = np.arange(-0.0001, 1.0001, 0.0001)
20+
h_b, e_b = np.histogram(predictions[ground_truth == 0], bins=rbbins, density=True)
21+
h_b_c = np.cumsum(h_b)
22+
h_r, e_r = np.histogram(predictions[ground_truth == 1], bins=rbbins, density=True)
23+
h_r_c = np.cumsum(h_r)
24+
25+
# h_b, e_b
26+
print(sum(ground_truth == 0), sum(ground_truth == 1))
27+
28+
fig = plt.figure(figsize=(9, 4), dpi=200)
29+
ax = fig.add_subplot(111)
30+
31+
rb_thres = np.array(list(range(len(h_b)))) / len(h_b)
32+
33+
ax.plot(
34+
rb_thres,
35+
h_r_c / np.max(h_r_c),
36+
label="False Negative Rate (FNR)",
37+
linewidth=1.5,
38+
)
39+
ax.plot(
40+
rb_thres,
41+
1 - h_b_c / np.max(h_b_c),
42+
label="False Positive Rate (FPR)",
43+
linewidth=1.5,
44+
)
45+
46+
mmce = (h_r_c / np.max(h_r_c) + 1 - h_b_c / np.max(h_b_c)) / 2
47+
ax.plot(
48+
rb_thres,
49+
mmce,
50+
"--",
51+
label="Mean misclassification error",
52+
color="gray",
53+
linewidth=1.5,
54+
)
55+
56+
ax.set_xlim([-0.05, 1.05])
57+
58+
ax.set_xticks(np.arange(0, 1.1, 0.1))
59+
ax.set_yticks(np.arange(0, 1.1, 0.1))
60+
61+
# vals = ax.get_yticks()
62+
# ax.set_yticklabels(['{:,.0%}'.format(x) for x in vals])
63+
64+
ax.set_yscale("log")
65+
ax.set_ylim([5e-4, 1])
66+
vals = ax.get_yticks()
67+
ax.set_yticklabels(
68+
["{:,.1%}".format(x) if x < 0.01 else "{:,.0%}".format(x) for x in vals]
69+
)
70+
71+
# thresholds:
72+
# thrs = [0.5, ]
73+
thrs = [0.5, 0.7]
74+
for t in thrs:
75+
m_t = rb_thres < t
76+
fnr = np.array(h_r_c / np.max(h_r_c))[m_t][-1]
77+
fpr = np.array(1 - h_b_c / np.max(h_b_c))[m_t][-1]
78+
print(t, fnr * 100, fpr * 100)
79+
# ax.vlines(t_1, 0, 1.1)
80+
ax.vlines(t, 0, max(fnr, fpr))
81+
ax.text(
82+
t - 0.05,
83+
max(fnr, fpr) + 0.01,
84+
f" {fnr*100:.1f}% FNR\n {fpr*100:.1f}% FPR",
85+
fontsize=10,
86+
)
87+
88+
ax.set_xlabel("$p_c$ score threshold")
89+
ax.set_ylabel("Cumulative percentage")
90+
ax.legend(loc="upper center")
91+
ax.grid(True, which="major", linewidth=0.5)
92+
ax.grid(True, which="minor", linewidth=0.3)
93+
plt.tight_layout()
94+
plt.show()
95+
96+
97+
class TailsLoss(tf.keras.losses.BinaryCrossentropy):
98+
def __init__(self, w_1: float = 1, w_2: float = 1, **kwargs):
99+
super(TailsLoss, self).__init__(**kwargs)
100+
self.w_1 = w_1
101+
self.w_2 = w_2
102+
103+
def call(self, y_true, y_pred):
104+
output = tf.convert_to_tensor(y_pred[..., 0])
105+
target = tf.cast(y_true[..., 0], output.dtype)
106+
107+
# l_1: binary crossentropy for the label
108+
l_1 = super(TailsLoss, self).call(target, output)
109+
w_1 = tf.cast(self.w_1, output.dtype)
110+
l_1 = tf.math.multiply(l_1, w_1)
111+
112+
# l_2: L1 loss
113+
l_2 = tf.norm(y_pred[..., 1:] - y_true[..., 1:], ord=1)
114+
115+
# l_2: L1 loss + L2 regularization
116+
# l_2 = tf.norm(y_pred[..., 1:] - y_true[..., 1:], ord=1) + \
117+
# 1e-3 * tf.norm(y_pred[..., 1:] - y_true[..., 1:], ord=2)
118+
119+
l_2 = tf.math.multiply(l_2, target)
120+
l_2 = tf.math.divide(l_2, tf.math.reduce_sum(target))
121+
w_2 = tf.cast(self.w_2, output.dtype)
122+
l_2 = tf.math.multiply(l_2, w_2)
123+
124+
return l_1 + l_2
125+
126+
127+
class LabelAccuracy(tf.keras.metrics.Metric):
128+
def __init__(self, name="label_accuracy", threshold=0.5, **kwargs):
129+
super(LabelAccuracy, self).__init__(name=name, **kwargs)
130+
self.total = self.add_weight(name="total", initializer="zeros")
131+
self.count = self.add_weight(name="count", initializer="zeros")
132+
self.threshold = float(threshold)
133+
134+
def update_state(self, y_true, y_pred, sample_weight=None):
135+
output = y_pred[..., 0]
136+
# target = tf.cast(y_true[..., 0], output.dtype)
137+
target = tf.cast(y_true[..., 0], tf.bool)
138+
139+
threshold = tf.cast(0.5, output.dtype)
140+
output = tf.cast(output > threshold, tf.bool)
141+
142+
# values = tf.cast(tf.math.equal(target, output), output.dtype)
143+
values = tf.cast(tf.math.equal(target, output), tf.float32)
144+
ones = tf.cast(tf.math.equal(target, target), tf.float32)
145+
146+
if sample_weight is not None:
147+
sample_weight = tf.cast(sample_weight, self.dtype)
148+
sample_weight = tf.broadcast_weights(sample_weight, values)
149+
values = tf.multiply(values, sample_weight)
150+
151+
self.count.assign_add(tf.math.reduce_sum(values, axis=-1))
152+
self.total.assign_add(tf.math.reduce_sum(ones, axis=-1))
153+
154+
def result(self):
155+
return tf.math.divide(self.count, self.total)
156+
157+
158+
class PositionRootMeanSquarredError(tf.keras.metrics.Metric):
159+
def __init__(self, name="position_rmse", scaling_factor=1, **kwargs):
160+
super(PositionRootMeanSquarredError, self).__init__(name=name, **kwargs)
161+
self.total = self.add_weight(name="total", initializer="zeros")
162+
self.rmse = self.add_weight(name="rmse", initializer="zeros")
163+
self.scaling_factor = float(scaling_factor)
164+
165+
def update_state(self, y_true, y_pred, sample_weight=None):
166+
output = y_pred[..., 1:]
167+
target = tf.cast(y_true[..., 1:], output.dtype)
168+
label = tf.cast(y_true[..., 0], output.dtype)
169+
170+
rmse = tf.math.reduce_mean(
171+
tf.math.sqrt(tf.math.squared_difference(output, target)), axis=-1
172+
)
173+
# only take positive examples into account:
174+
rmse = tf.math.multiply(rmse, label)
175+
176+
self.rmse.assign_add(tf.math.reduce_sum(rmse, axis=-1))
177+
# only count the positive examples:
178+
self.total.assign_add(tf.math.reduce_sum(label, axis=-1))
179+
180+
def result(self):
181+
sf = tf.constant(self.scaling_factor, dtype=self.rmse.dtype.base_dtype)
182+
return tf.math.multiply(sf, tf.math.divide(self.rmse, self.total))
183+
184+
185+
def train_and_eval(
186+
train_dataset,
187+
val_dataset,
188+
test_dataset,
189+
steps_per_epoch_train,
190+
steps_per_epoch_val,
191+
epochs,
192+
class_weight,
193+
model_name: str = "tails",
194+
tag="20210101",
195+
w_1: float = 1.2,
196+
w_2: float = 1,
197+
class_threshold: float = 0.5,
198+
scaling_factor=256,
199+
input_shape=(256, 256, 3),
200+
weights: str = None,
201+
save_model=False,
202+
verbose=False,
203+
**kwargs,
204+
):
205+
classifier = tails.models.DNN(name=model_name)
206+
207+
tails_loss = TailsLoss(name="loss", w_1=w_1, w_2=w_2)
208+
label_accuracy = LabelAccuracy(threshold=class_threshold)
209+
# convert position RMSE to pixels
210+
position_rmse = PositionRootMeanSquarredError(scaling_factor=scaling_factor)
211+
212+
learning_rate = kwargs.get("learning_rate", 3e-4)
213+
patience = kwargs.get("patience", 30)
214+
215+
classifier.setup(
216+
input_shape=input_shape,
217+
n_output_neurons=3,
218+
architecture="tails",
219+
loss=tails_loss,
220+
optimizer="adam",
221+
lr=learning_rate, # epsilon=1e-3, beta_1=0.7,
222+
metrics=[label_accuracy, position_rmse],
223+
patience=patience,
224+
monitor="val_position_rmse",
225+
restore_best_weights=True,
226+
callbacks=("early_stopping", "tensorboard"),
227+
tag=tag,
228+
logdir="logs",
229+
)
230+
231+
# pre-load weights?
232+
if weights is not None:
233+
classifier.model.load_weights(weights)
234+
235+
classifier.train(
236+
train_dataset,
237+
val_dataset,
238+
steps_per_epoch_train,
239+
steps_per_epoch_val,
240+
epochs=epochs,
241+
class_weight=class_weight,
242+
verbose=True,
243+
)
244+
245+
# evaluate
246+
stats = classifier.evaluate(test_dataset)
247+
if verbose:
248+
log(stats)
249+
250+
if save_model:
251+
classifier.model.save_weights(f"{model_name}-{tag}")
252+
253+
254+
if __name__ == "__main__":
255+
fire.Fire(train_and_eval)

0 commit comments

Comments
 (0)