Merged
Changes from 18 commits
Commits (22)
76efd8b - fix: epoch always started with disc_loss = 0, resulting in disc never… (bnb32, Feb 5, 2025)
96b151a - only use output features in content loss (bnb32, Feb 5, 2025)
a722236 - save / load check for disc training fix (bnb32, Feb 5, 2025)
2808f1f - use n_obs from previous epoch for loss running mean (bnb32, Feb 5, 2025)
55774ee - content loss edit to account for undefined hr_out_features (bnb32, Feb 6, 2025)
939897f - adding gen loss from previous epoch for running means (bnb32, Feb 6, 2025)
58bf727 - Using running dataframe records of training and validation batch loss… (bnb32, Feb 6, 2025)
b49456f - Initialize records from history for loaded models (bnb32, Feb 6, 2025)
6777b97 - use index to append loss details record (bnb32, Feb 7, 2025)
6069637 - changed trained_frac naming - doesn't make sense to prefix these with… (bnb32, Feb 7, 2025)
cdd1edf - missed new method `get_hr_exo_and_loss` from cherry picks (bnb32, Feb 7, 2025)
acb04b2 - added `loss_mean_window` arg, material derivative loss with extremes,… (bnb32, Feb 12, 2025)
b03a539 - fix: wasn't updating running mean `loss_details` (bnb32, Feb 12, 2025)
41ce39b - material derivative loss test fix (bnb32, Feb 12, 2025)
f26dc86 - typo (bnb32, Feb 12, 2025)
cee57bd - use epoch mean in history but window for running mean (bnb32, Feb 13, 2025)
9d6a51d - history indexing catch (bnb32, Feb 17, 2025)
71c6ef1 - floating point error adjustment (bnb32, Feb 17, 2025)
2cb320b - renaming `_get_batch_loss_details`. moved weight into to start of `tr… (bnb32, Feb 19, 2025)
7029740 - Moving weight initialization to start of training. Use batch_handler.… (bnb32, Feb 19, 2025)
cb57ef3 - Using queue_shape doesn't work since some queues only include high re… (bnb32, Feb 19, 2025)
48cbd5c - Removed `_get_hr_exo_and_loss`. Renamed `_run_gradient_descent` and `… (bnb32, Feb 21, 2025)
92 changes: 53 additions & 39 deletions sup3r/models/abstract.py
@@ -13,6 +13,7 @@
from warnings import warn

import numpy as np
import pandas as pd
import tensorflow as tf
from phygnn import CustomNetwork
from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat
@@ -621,6 +622,8 @@ def __init__(self):
self._gen = None
self._means = None
self._stdevs = None
self._train_record = pd.DataFrame()
self._val_record = pd.DataFrame()

def load_network(self, model, name):
"""Load a CustomNetwork object from hidden layers config, .json file
@@ -969,6 +972,17 @@ def load_saved_params(out_dir, verbose=True):

return params

def _init_records(self):
"""Initialize running records used to compute loss details running
means"""
if self._history is not None:
train_cols = [c for c in self._history.columns if 'train_' in c]
val_cols = [c for c in self._history.columns if 'val_' in c]
self._train_record = self._history[train_cols].iloc[-1:]
self._train_record = self._train_record.reset_index(drop=True)
self._val_record = self._history[val_cols].iloc[-1:]
self._val_record = self._val_record.reset_index(drop=True)

def get_high_res_exo_input(self, high_res):
"""Get exogenous feature data from high_res

@@ -1077,49 +1091,39 @@ def get_optimizer_state(cls, optimizer):
return state

@staticmethod
def update_loss_details(loss_details, new_data, batch_len, prefix=None):
def update_loss_details(record, new_data, max_batches, prefix=None):
"""Update a dictionary of loss_details with loss information from a new
batch.

Parameters
----------
loss_details : dict
Namespace of the breakdown of loss components where each value is a
running average at the current state in the epoch.
record : pd.DataFrame
Details for the last N batches, where N is the number of batches in
an epoch, used to compute the running means.
new_data : dict
Namespace of the breakdown of loss components for a single new
batch.
batch_len : int
Length of the incoming batch.
max_batches : int
Maximum number of batches to use for the running mean of loss
details
prefix : None | str
Option to prefix the names of the loss data when saving to the
loss_details dictionary.
loss_details dictionary. This is usually 'train_' or 'val_'

Returns
-------
loss_details : dict
Same as input loss_details but with running averages updated.
record : pd.DataFrame
Same as input with details from ``new_data`` added and only the
last ``max_batches`` rows kept.
"""
assert 'n_obs' in loss_details, 'loss_details must have n_obs to start'
prior_n_obs = loss_details['n_obs']
new_n_obs = prior_n_obs + batch_len

new_index = 0 if len(record) == 0 else record.index[-1] + 1
for k, v in new_data.items():
key = k if prefix is None else prefix + k
# only add prefix if key doesn't already include the prefix - no
# point in adding 'train_' to keys like 'disc_train_frac'
key = k if prefix is None or prefix in k else prefix + k
new_value = numpy_if_tensor(v)

if key in loss_details:
saved_value = loss_details[key]
saved_value *= prior_n_obs
saved_value += batch_len * new_value
saved_value /= new_n_obs
loss_details[key] = saved_value
else:
loss_details[key] = new_value

loss_details['n_obs'] = new_n_obs

return loss_details
record.loc[new_index, key] = new_value
return record.iloc[-max_batches:]
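
As an illustration (not code from this PR), a record updated this way behaves like a fixed-length window: each batch appends one row, only the last `max_batches` rows are kept, and the running mean is just the column-wise mean of what remains. The batch values, the 'train_' prefix, and max_batches=2 below are made up:

import pandas as pd

record = pd.DataFrame()
batches = [
    {'loss_gen': 0.50, 'loss_disc': 0.70},
    {'loss_gen': 0.45, 'loss_disc': 0.72},
    {'loss_gen': 0.40, 'loss_disc': 0.69},
]

for new_data in batches:
    # mirror update_loss_details with prefix='train_' and max_batches=2
    new_index = 0 if len(record) == 0 else record.index[-1] + 1
    for k, v in new_data.items():
        record.loc[new_index, 'train_' + k] = v
    record = record.iloc[-2:]  # keep only the last max_batches rows

running_means = record.mean()  # train_loss_gen -> (0.45 + 0.40) / 2 = 0.425

Unlike the old n_obs-weighted incremental average, this mean depends only on the most recent window of batches, which appears to be the motivation for the `loss_mean_window` arg mentioned in the commit list.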

@staticmethod
def log_loss_details(loss_details, level='INFO'):
@@ -1134,15 +1138,11 @@ def log_loss_details(loss_details, level='INFO'):
Log level (e.g. INFO, DEBUG)
"""
for k, v in sorted(loss_details.items()):
if k != 'n_obs':
if isinstance(v, str):
msg_format = '\t{}: {}'
else:
msg_format = '\t{}: {:.2e}'
if level.lower() == 'info':
logger.info(msg_format.format(k, v))
else:
logger.debug(msg_format.format(k, v))
msg_format = '\t{}: {}' if isinstance(v, str) else '\t{}: {:.2e}'
if level.lower() == 'info':
logger.info(msg_format.format(k, v))
else:
logger.debug(msg_format.format(k, v))

@staticmethod
def early_stop(history, column, threshold=0.005, n_epoch=5):
@@ -1256,9 +1256,8 @@ def finish_epoch(
"""
self.log_loss_details(loss_details)
self._history.at[epoch, 'elapsed_time'] = time.time() - t0
for key, value in loss_details.items():
if key != 'n_obs':
self._history.at[epoch, key] = value
entry = np.vstack(list(loss_details.values())).T
self._history.loc[epoch, list(loss_details.keys())] = entry

last_epoch = epoch == epochs[-1]
chp = checkpoint_int is not None and (epoch % checkpoint_int) == 0
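
For context, a hedged sketch (hypothetical, not the PR's exact reduction) of how the per-epoch `loss_details` passed into `finish_epoch` could be collapsed from the running records before being logged and written into the history row:

# train_record / val_record as in the earlier sketches; mean over the retained rows
loss_details = {**train_record.mean().to_dict(), **val_record.mean().to_dict()}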
@@ -1588,6 +1587,21 @@ def _tf_generate(self, low_res, hi_res_exo=None):

return hi_res

def _get_hr_exo_and_loss(
self,
low_res,
hi_res_true,
**calc_loss_kwargs,
):
"""Get high-resolution exogenous data, generate synthetic output, and
compute loss."""
hi_res_exo = self.get_high_res_exo_input(hi_res_true)
hi_res_gen = self._tf_generate(low_res, hi_res_exo)
loss, loss_details = self.calc_loss(
hi_res_true, hi_res_gen, **calc_loss_kwargs
)
return loss, loss_details, hi_res_gen, hi_res_exo

Review discussion on `_get_hr_exo_and_loss`:

Member: I could go either way on this, but my knee-jerk reaction is that breaking these lines out into a separate function in a different file just makes the stack trace deeper for little benefit. This function is only called in one place in a different file, in a relatively short parent function. Seems like we could leave it as-is, with fewer nested functions? My gut feeling is that three direct function calls without any logic are portable enough to not be packaged into a separate function.

Collaborator (author): A lot of these extractions are motivated by the work on models with observations. I could delay this until that PR if you prefer.

Member: I guess I would have to see that work too, but I'd be shocked if a 14-line function really helps reduce the burden of 3 function calls. I really think we should just call the 3 functions directly. More nested functions reduce docstring quality and make it way harder to trace args/kwargs.

Collaborator (author): This removes ~50 lines of duplication in the obs branch, but we can decide if it's worth doing in that PR.
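
For context, a hypothetical sketch of how a single generator training step could use this helper inside a gradient tape. The names `model`, `low_res`, `hi_res_true`, `optimizer`, `generator_weights`, and `weight_gen_advers` are assumptions here, not taken from this diff:

import tensorflow as tf

# assumed to exist already: model, optimizer, and a batch of (low_res, hi_res_true)
with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(model.generator_weights)
    loss, loss_details, hi_res_gen, hi_res_exo = model._get_hr_exo_and_loss(
        low_res, hi_res_true, weight_gen_advers=1e-3
    )
grads = tape.gradient(loss, model.generator_weights)
optimizer.apply_gradients(zip(grads, model.generator_weights))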

@tf.function
def get_single_grad(
self,