#!/usr/bin/env python
import fire
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tails.models
from tails.utils import log


def fnr_vs_fpr(predictions, ground_truth):
    """Plot the false negative rate (FNR), false positive rate (FPR), and
    mean misclassification error as functions of the score threshold."""
    rbbins = np.arange(-0.0001, 1.0001, 0.0001)
    h_b, _ = np.histogram(predictions[ground_truth == 0], bins=rbbins, density=True)
    h_b_c = np.cumsum(h_b)
    h_r, _ = np.histogram(predictions[ground_truth == 1], bins=rbbins, density=True)
    h_r_c = np.cumsum(h_r)

    print(
        f"negatives: {np.sum(ground_truth == 0)}, "
        f"positives: {np.sum(ground_truth == 1)}"
    )

    fig = plt.figure(figsize=(9, 4), dpi=200)
    ax = fig.add_subplot(111)

    rb_thres = np.arange(len(h_b)) / len(h_b)

    ax.plot(
        rb_thres,
        h_r_c / np.max(h_r_c),
        label="False Negative Rate (FNR)",
        linewidth=1.5,
    )
    ax.plot(
        rb_thres,
        1 - h_b_c / np.max(h_b_c),
        label="False Positive Rate (FPR)",
        linewidth=1.5,
    )

    mmce = (h_r_c / np.max(h_r_c) + 1 - h_b_c / np.max(h_b_c)) / 2
    ax.plot(
        rb_thres,
        mmce,
        "--",
        label="Mean misclassification error",
        color="gray",
        linewidth=1.5,
    )

    ax.set_xlim([-0.05, 1.05])

    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.set_yticks(np.arange(0, 1.1, 0.1))

    ax.set_yscale("log")
    ax.set_ylim([5e-4, 1])
    vals = ax.get_yticks()
    ax.set_yticklabels(
        ["{:,.1%}".format(x) if x < 0.01 else "{:,.0%}".format(x) for x in vals]
    )

    # annotate FNR/FPR at selected score thresholds:
    thrs = [0.5, 0.7]
    for t in thrs:
        m_t = rb_thres < t
        fnr = np.array(h_r_c / np.max(h_r_c))[m_t][-1]
        fpr = np.array(1 - h_b_c / np.max(h_b_c))[m_t][-1]
        print(f"threshold: {t}, FNR: {fnr * 100:.2f}%, FPR: {fpr * 100:.2f}%")
        ax.vlines(t, 0, max(fnr, fpr))
        ax.text(
            t - 0.05,
            max(fnr, fpr) + 0.01,
            f" {fnr*100:.1f}% FNR\n {fpr*100:.1f}% FPR",
            fontsize=10,
        )

    ax.set_xlabel("$p_c$ score threshold")
    ax.set_ylabel("Cumulative percentage")
    ax.legend(loc="upper center")
    ax.grid(True, which="major", linewidth=0.5)
    ax.grid(True, which="minor", linewidth=0.3)
    plt.tight_layout()
    plt.show()


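# A minimal, self-contained sketch of how fnr_vs_fpr can be exercised; the
# beta-distributed scores below are an illustrative assumption, not real
# model output:
def _demo_fnr_vs_fpr(n: int = 10_000, seed: int = 42):
    rng = np.random.default_rng(seed)
    ground_truth = rng.integers(0, 2, size=n)
    # negatives cluster near 0, positives near 1
    predictions = np.where(
        ground_truth == 1, rng.beta(8, 2, size=n), rng.beta(2, 8, size=n)
    )
    fnr_vs_fpr(predictions, ground_truth)

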
class TailsLoss(tf.keras.losses.BinaryCrossentropy):
    """Weighted sum of a binary crossentropy loss on the label (column 0)
    and an L1 loss on the position (columns 1:), the latter averaged over
    the positive examples only."""

    def __init__(self, w_1: float = 1, w_2: float = 1, **kwargs):
        super().__init__(**kwargs)
        self.w_1 = w_1
        self.w_2 = w_2

    def call(self, y_true, y_pred):
        output = tf.convert_to_tensor(y_pred[..., 0])
        target = tf.cast(y_true[..., 0], output.dtype)

        # l_1: binary crossentropy for the label
        l_1 = super().call(target, output)
        w_1 = tf.cast(self.w_1, output.dtype)
        l_1 = tf.math.multiply(l_1, w_1)

        # l_2: L1 loss for the position
        l_2 = tf.norm(y_pred[..., 1:] - y_true[..., 1:], ord=1)
        # alternative: L1 loss + L2 regularization
        # l_2 = tf.norm(y_pred[..., 1:] - y_true[..., 1:], ord=1) + \
        #     1e-3 * tf.norm(y_pred[..., 1:] - y_true[..., 1:], ord=2)

        # only penalize the position on positive examples; divide_no_nan
        # avoids NaNs on batches with no positive examples
        l_2 = tf.math.multiply(l_2, target)
        l_2 = tf.math.divide_no_nan(l_2, tf.math.reduce_sum(target))
        w_2 = tf.cast(self.w_2, output.dtype)
        l_2 = tf.math.multiply(l_2, w_2)

        return l_1 + l_2


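# A minimal sketch of TailsLoss on toy tensors; the shapes and values are
# illustrative assumptions (a batch of two [score, x, y] triplets, one
# positive and one negative example):
def _demo_tails_loss():
    y_true = tf.constant([[1.0, 0.5, 0.5], [0.0, 0.0, 0.0]])
    y_pred = tf.constant([[0.9, 0.4, 0.6], [0.1, 0.2, 0.3]])
    loss = TailsLoss(w_1=1.2, w_2=1.0)
    print(loss(y_true, y_pred).numpy())

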
class LabelAccuracy(tf.keras.metrics.Metric):
    """Binary accuracy of the label (column 0) at a fixed score threshold."""

    def __init__(self, name="label_accuracy", threshold=0.5, **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name="total", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")
        self.threshold = float(threshold)

    def update_state(self, y_true, y_pred, sample_weight=None):
        output = y_pred[..., 0]
        target = tf.cast(y_true[..., 0], tf.bool)

        threshold = tf.cast(self.threshold, output.dtype)
        output = tf.cast(output > threshold, tf.bool)

        values = tf.cast(tf.math.equal(target, output), tf.float32)
        # count all examples, weighted or not
        ones = tf.ones_like(values)

        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, self.dtype)
            sample_weight = tf.broadcast_to(sample_weight, tf.shape(values))
            values = tf.multiply(values, sample_weight)

        self.count.assign_add(tf.math.reduce_sum(values, axis=-1))
        self.total.assign_add(tf.math.reduce_sum(ones, axis=-1))

    def result(self):
        return tf.math.divide(self.count, self.total)


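# A quick sketch of LabelAccuracy on toy tensors (shapes are assumptions:
# column 0 carries the label, the remaining columns the position):
def _demo_label_accuracy():
    metric = LabelAccuracy(threshold=0.5)
    y_true = tf.constant([[1.0, 0.5, 0.5], [0.0, 0.1, 0.9]])
    y_pred = tf.constant([[0.8, 0.4, 0.6], [0.7, 0.2, 0.3]])
    metric.update_state(y_true, y_pred)
    print(metric.result().numpy())  # 0.5: one of the two labels is correct

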
class PositionRootMeanSquaredError(tf.keras.metrics.Metric):
    """RMSE of the predicted position (columns 1:), averaged over the
    positive examples only and multiplied by a scaling factor
    (e.g., to convert unit coordinates to pixels)."""

    def __init__(self, name="position_rmse", scaling_factor=1, **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name="total", initializer="zeros")
        self.rmse = self.add_weight(name="rmse", initializer="zeros")
        self.scaling_factor = float(scaling_factor)

    def update_state(self, y_true, y_pred, sample_weight=None):
        output = y_pred[..., 1:]
        target = tf.cast(y_true[..., 1:], output.dtype)
        label = tf.cast(y_true[..., 0], output.dtype)

        # per-example RMSE over the position coordinates
        rmse = tf.math.sqrt(
            tf.math.reduce_mean(tf.math.squared_difference(output, target), axis=-1)
        )
        # only take positive examples into account:
        rmse = tf.math.multiply(rmse, label)

        self.rmse.assign_add(tf.math.reduce_sum(rmse, axis=-1))
        # only count the positive examples:
        self.total.assign_add(tf.math.reduce_sum(label, axis=-1))

    def result(self):
        sf = tf.constant(self.scaling_factor, dtype=self.rmse.dtype.base_dtype)
        return tf.math.multiply(sf, tf.math.divide_no_nan(self.rmse, self.total))


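# A minimal sketch of the position metric on toy tensors; the uniform 0.1
# offset in the predicted position is an illustrative assumption:
def _demo_position_rmse():
    metric = PositionRootMeanSquaredError(scaling_factor=256)
    y_true = tf.constant([[1.0, 0.5, 0.5], [0.0, 0.0, 0.0]])
    y_pred = tf.constant([[1.0, 0.4, 0.6], [0.0, 0.3, 0.3]])
    metric.update_state(y_true, y_pred)
    # only the positive example counts: 0.1 in unit coordinates * 256 = 25.6 px
    print(metric.result().numpy())

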
def train_and_eval(
    train_dataset,
    val_dataset,
    test_dataset,
    steps_per_epoch_train,
    steps_per_epoch_val,
    epochs,
    class_weight,
    model_name: str = "tails",
    tag="20210101",
    w_1: float = 1.2,
    w_2: float = 1,
    class_threshold: float = 0.5,
    scaling_factor=256,
    input_shape=(256, 256, 3),
    weights: str = None,
    save_model=False,
    verbose=False,
    **kwargs,
):
    """Set up, train, and evaluate a Tails model."""
    classifier = tails.models.DNN(name=model_name)

    tails_loss = TailsLoss(name="loss", w_1=w_1, w_2=w_2)
    label_accuracy = LabelAccuracy(threshold=class_threshold)
    # scaling_factor converts the position RMSE to pixels
    position_rmse = PositionRootMeanSquaredError(scaling_factor=scaling_factor)

    learning_rate = kwargs.get("learning_rate", 3e-4)
    patience = kwargs.get("patience", 30)

    classifier.setup(
        input_shape=input_shape,
        n_output_neurons=3,
        architecture="tails",
        loss=tails_loss,
        optimizer="adam",
        lr=learning_rate,  # epsilon=1e-3, beta_1=0.7,
        metrics=[label_accuracy, position_rmse],
        patience=patience,
        monitor="val_position_rmse",
        restore_best_weights=True,
        callbacks=("early_stopping", "tensorboard"),
        tag=tag,
        logdir="logs",
    )

    # pre-load weights?
    if weights is not None:
        classifier.model.load_weights(weights)

    classifier.train(
        train_dataset,
        val_dataset,
        steps_per_epoch_train,
        steps_per_epoch_val,
        epochs=epochs,
        class_weight=class_weight,
        verbose=True,
    )

    # evaluate
    stats = classifier.evaluate(test_dataset)
    if verbose:
        log(stats)

    if save_model:
        classifier.model.save_weights(f"{model_name}-{tag}")


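# A sketch of invoking train_and_eval from Python rather than the CLI (the
# tf.data pipelines and class_weight dict below are assumed to be built
# elsewhere; fire cannot pass dataset objects on the command line):
#
#   train_and_eval(
#       train_dataset, val_dataset, test_dataset,
#       steps_per_epoch_train=100, steps_per_epoch_val=10, epochs=300,
#       class_weight={0: 1.0, 1: 1.0}, tag="20210101", save_model=True,
#   )

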
if __name__ == "__main__":
    fire.Fire(train_and_eval)