7 changes: 4 additions & 3 deletions sup3r/models/abstract.py
@@ -22,7 +22,7 @@
 from sup3r.preprocessing.data_handlers import ExoData
 from sup3r.preprocessing.utilities import numpy_if_tensor
 from sup3r.utilities import VERSION_RECORD
-from sup3r.utilities.utilities import camel_to_underscore, safe_cast
+from sup3r.utilities.utilities import Timer, camel_to_underscore, safe_cast

 from .utilities import SUP3R_LAYERS, SUP3R_OBS_LAYERS, TensorboardMixIn

@@ -40,11 +40,12 @@ def __init__(self):
         super().__init__()
         self.gpu_list = tf.config.list_physical_devices('GPU')
         self.default_device = '/cpu:0' if len(self.gpu_list) == 0 else '/gpu:0'
-        self._version_record = VERSION_RECORD
+        self.timer = Timer()
         self.name = None
-        self._meta = None
         self.loss_name = None
         self.loss_fun = None
+        self._version_record = VERSION_RECORD
+        self._meta = None
         self._history = None
         self._optimizer = None
         self._gen = None
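Note on the new import: `Timer` is used later in this PR as `self.timer(self._train_batch, log=True)(...)`, and `self.timer.log` is written to TensorBoard in `_post_batch`. A minimal sketch of the assumed semantics (the real implementation lives in `sup3r.utilities.utilities` and may differ):

    import time
    from functools import wraps

    class Timer:
        """Sketch of a callable timer: wraps a function and records elapsed
        time (assumed semantics, not the actual sup3r implementation)."""

        def __init__(self):
            self.log = {}  # e.g. {'time_train_batch': 1.23}, in seconds

        def __call__(self, func, log=False):
            @wraps(func)
            def wrapper(*args, **kwargs):
                t0 = time.time()
                out = func(*args, **kwargs)
                if log:
                    # record the most recent elapsed time for this function
                    self.log[f'time_{func.__name__}'] = time.time() - t0
                return out

            return wrapper
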
28 changes: 15 additions & 13 deletions sup3r/models/base.py
@@ -638,7 +638,7 @@ def train(
         adaptive_update_bounds=(0.9, 0.99),
         adaptive_update_fraction=0.0,
         multi_gpu=False,
-        tensorboard_log=True,
+        tensorboard_log=False,
         tensorboard_profile=False,
     ):
         """Train the GAN model on real low res data and real high res data
@@ -1062,8 +1062,9 @@ def _post_batch(self, ib, b_loss_details, n_batches, previous_means):
             prefix='train_',
         )

-        self.dict_to_tensorboard(b_loss_details)
-        self.dict_to_tensorboard(self.timer.log)
+        if self._tb_writer is not None:
+            self.dict_to_tensorboard(b_loss_details)
+            self.dict_to_tensorboard(self.timer.log)

         trained_gen = bool(self._train_record['gen_train_frac'].values[-1])
         trained_disc = bool(self._train_record['disc_train_frac'].values[-1])
@@ -1154,14 +1155,15 @@ def _train_epoch(
         tf.summary.trace_on(graph=True, profiler=True)

         for ib, batch in enumerate(batch_handler):
+            start = time.time()
+
             b_loss_details = {}
             loss_disc = loss_means['train_loss_disc']
             disc_too_good = loss_disc <= disc_th_low
             disc_too_bad = (loss_disc > disc_th_high) and train_disc
             gen_too_good = disc_too_bad

-            start = time.time()
-            b_loss_details = self._train_batch(
+            b_loss_details = self.timer(self._train_batch, log=True)(
                 batch,
                 train_gen,
                 only_gen,
@@ -1172,14 +1174,14 @@
                 weight_gen_advers,
                 multi_gpu,
             )
-            elapsed = time.time() - start
-            logger.info('Finished batch in {:.4f} seconds'.format(elapsed))
-
-            loss_means = self._post_batch(
-                ib,
-                b_loss_details,
-                len(batch_handler),
-                loss_means,
+
+            loss_means = self.timer(self._post_batch, log=True)(
+                ib, b_loss_details, len(batch_handler), loss_means
             )
+
+            logger.info(
+                f'Finished batch step {ib + 1} / {len(batch_handler)} in '
+                f'{time.time() - start:.4f} seconds'
+            )

         self.total_batches += len(batch_handler)
13 changes: 6 additions & 7 deletions sup3r/models/utilities.py
@@ -109,13 +109,12 @@ def dict_to_tensorboard(self, entry):
         entry: dict
             Dictionary of values to write to tensorboard log file
         """
-        if self._tb_writer is not None:
-            with self._tb_writer.as_default():
-                for name, value in entry.items():
-                    if isinstance(value, str):
-                        tf.summary.text(name, value, self.total_batches)
-                    else:
-                        tf.summary.scalar(name, value, self.total_batches)
+        with self._tb_writer.as_default():
+            for name, value in entry.items():
+                if isinstance(value, str):
+                    tf.summary.text(name, value, self.total_batches)
+                else:
+                    tf.summary.scalar(name, value, self.total_batches)

     def profile_to_tensorboard(self, name):
         """Write profile data to tensorboard log file.
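With `tensorboard_log` now defaulting to `False` in `Sup3rGan.train`, `_tb_writer` can be `None`, so the None-check moves to the caller (`_post_batch` above) and this method assumes a live writer. For reference, a self-contained example of the same scalar/text dispatch using standard TensorFlow 2.x summary APIs (the logdir and values are illustrative):

    import tensorflow as tf

    writer = tf.summary.create_file_writer('/tmp/tb_demo')  # illustrative logdir
    entry = {'train_loss_gen': 0.42, 'loss_name': 'mse'}    # scalar and text values
    step = 7                                                # plays the role of total_batches

    with writer.as_default():
        for name, value in entry.items():
            if isinstance(value, str):
                tf.summary.text(name, value, step=step)
            else:
                tf.summary.scalar(name, value, step=step)
    writer.flush()
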
4 changes: 4 additions & 0 deletions sup3r/postprocessing/writers/base.py
@@ -454,6 +454,10 @@ def get_lat_lon(cls, low_res_lat_lon, shape):
             (spatial_1, spatial_2, 2)
             Last dimension has ordering (lat, lon)
         """
+        assert (
+            low_res_lat_lon.shape[0] > 1 and low_res_lat_lon.shape[1] > 1
+        ), 'low res lat/lon must have at least 2 rows and 2 columns'
+
         logger.debug('Getting high resolution lat / lon grid')

         # ensure lons are between -180 and 180
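The new assert guards the spacing inference: building the high-resolution grid requires lat/lon spacing estimated from adjacent low-resolution points, and a single row or column leaves nothing to difference. A toy numpy illustration of the failure mode (the actual interpolation code differs):

    import numpy as np

    lats_ok = np.array([40.0, 39.9, 39.8])  # 3 rows -> spacing is well defined
    print(np.diff(lats_ok))                 # [-0.1 -0.1]

    lats_bad = np.array([40.0])             # 1 row -> nothing to difference
    print(np.diff(lats_bad))                # [] -> downstream spacing is undefined
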
28 changes: 10 additions & 18 deletions sup3r/preprocessing/batch_queues/abstract.py
@@ -240,8 +240,15 @@ def get_batch(self) -> DsetTuple:
         """Get batch from queue or directly from a ``Sampler`` through
         ``sample_batch``."""
         if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0:
-            return self.sample_batch()
-        return self.queue.dequeue()
+            samples = self.sample_batch()
+        else:
+            samples = self.queue.dequeue()
+        if self.sample_shape[2] == 1:
+            if isinstance(samples, (list, tuple)):
+                samples = tuple(s[..., 0, :] for s in samples)
+            else:
+                samples = samples[..., 0, :]
+        return self.post_proc(samples)

     @property
     def running(self):
@@ -275,7 +282,6 @@ def enqueue_batches(self) -> None:
         log_time = time.time()
         while self.running:
             needed = max(self.queue_cap - self.queue_len, 0)
-            needed = min(self.max_workers, needed)
             if needed > 0:
                 batches = self.sample_batches(n_batches=needed)
                 if needed > 1 and self.max_workers > 1:
@@ -300,22 +306,8 @@ def __next__(self) -> DsetTuple:
             Batch object with batch.low_res and batch.high_res attributes
         """
         if self._batch_count < self.n_batches:
-            self.timer.start()
-            samples = self.get_batch()
-            if self.sample_shape[2] == 1:
-                if isinstance(samples, (list, tuple)):
-                    samples = tuple(s[..., 0, :] for s in samples)
-                else:
-                    samples = samples[..., 0, :]
-            batch = self.post_proc(samples)
-            self.timer.stop()
+            batch = self.timer(self.get_batch, log=self.verbose)()
             self._batch_count += 1
-            if self.verbose:
-                logger.debug(
-                    'Batch step %s finished in %s.',
-                    self._batch_count,
-                    self.timer.elapsed_str,
-                )
         else:
             raise StopIteration
         return batch
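The `sample_shape[2] == 1` squeeze that moved into `get_batch` drops a singleton time axis so spatial-only batches come out 4D regardless of whether they arrive eagerly or from the queue. A shape-level illustration with a numpy stand-in for the queued tensors:

    import numpy as np

    # (batch, lat, lon, time, features) with a single time step
    samples = np.zeros((8, 10, 10, 1, 3))
    squeezed = samples[..., 0, :]  # same indexing as get_batch above
    print(squeezed.shape)          # (8, 10, 10, 3)
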
77 changes: 50 additions & 27 deletions sup3r/preprocessing/rasterizers/exo.py
@@ -141,6 +141,8 @@ def __post_init__(self):
         self._hr_lat_lon = None
         self._source_lat_lon = None
         self._hr_time_index = None
+        self._nn = None
+        self._dists = None
         self.input_handler_kwargs = self.input_handler_kwargs or {}
         self.source_handler_kwargs = self.source_handler_kwargs or {}
         InputHandler = get_input_handler_class(self.input_handler_name)
@@ -151,9 +153,9 @@
     @property
     def source_handler(self):
         """Get the Loader object that handles the exogenous data file."""
-        assert (
-            self.source_files is not None
-        ), 'source_files must be provided to BaseExoRasterizer'
+        assert self.source_files is not None, (
+            'source_files must be provided to BaseExoRasterizer'
+        )
         if self._source_handler is None:
             self._source_handler = Loader(
                 self.source_files,
@@ -278,6 +280,11 @@ def get_distance_upper_bound(self):
         """Maximum distance (float) to map high-resolution data from
         source_files to the low-resolution file_paths input."""
         if self.distance_upper_bound is None:
+            assert self.hr_lat_lon.shape[0] > 1, (
+                'hr_lat_lon must have at least 2 lat points to calculate '
+                'distance upper bound. Either expand the grid or provide a '
+                'distance_upper_bound explicitly.'
+            )
             diff = da.diff(self.hr_lat_lon, axis=0)
             diff = da.abs(da.median(diff, axis=0)).max()
             self.distance_upper_bound = np.asarray(diff)
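This assert protects the median-spacing estimate below it: with a single latitude row, `da.diff(..., axis=0)` is empty and the median is undefined. The same computation rendered in plain numpy on a toy 3x2 grid:

    import numpy as np

    hr_lat_lon = np.array([[[40.0, -105.0], [40.0, -104.9]],
                           [[39.9, -105.0], [39.9, -104.9]],
                           [[39.8, -105.0], [39.8, -104.9]]])
    diff = np.diff(hr_lat_lon, axis=0)             # row-to-row (lat, lon) deltas
    bound = np.abs(np.median(diff, axis=0)).max()  # ~0.1 degrees
    print(bound)
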
@@ -296,15 +303,29 @@ def tree(self):
             self._tree = KDTree(self.hr_lat_lon.reshape((-1, 2)))
         return self._tree

+    def query_tree(self, lat_lon):
+        """Query the KDTree for the nearest neighbor indices and distances
+        for the given lat_lon points."""
+        return self.tree.query(
+            lat_lon,
+            distance_upper_bound=self.get_distance_upper_bound(),
+        )
+
     @property
     def nn(self):
         """Get the nearest neighbor indices. This uses a single neighbor by
         default"""
-        _, nn = self.tree.query(
-            self.source_lat_lon,
-            distance_upper_bound=self.get_distance_upper_bound(),
-        )
-        return nn
+        if self._nn is None:
+            self._dists, self._nn = self.query_tree(self.source_lat_lon)
+        return self._nn
+
+    @property
+    def dists(self):
+        """Get the nearest neighbor distances. This uses a single neighbor
+        by default"""
+        if self._dists is None:
+            self._dists, self._nn = self.query_tree(self.source_lat_lon)
+        return self._dists

     @property
     def data(self):
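Caching `_nn` and `_dists` from one `query_tree` call avoids repeated KDTree queries. Note that `scipy.spatial.KDTree.query` with `distance_upper_bound` marks unmatched points with infinite distance and index `n` (the number of tree points), the sentinel that `_get_data_3d` filters out as `n_target`. A small standalone example:

    import numpy as np
    from scipy.spatial import KDTree

    grid = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    tree = KDTree(grid)
    points = np.array([[0.1, 0.1], [5.0, 5.0]])  # second point is far from the grid
    d, i = tree.query(points, distance_upper_bound=0.5)
    print(d)  # [0.1414...,   inf]
    print(i)  # [0, 4] -> 4 == len(grid) is the 'no neighbor' sentinel
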
@@ -371,6 +392,10 @@ def get_data(self):
             data_vars = {self.feature: data_vars}
         return Sup3rX(xr.Dataset(coords=self.coords, data_vars=data_vars))

+    def _mean_fill(self, x):
+        """Compute standard average for a group of data."""
+        return x[self.feature].mean()
+
     def _get_data_2d(self):
         """Get a raster of source values corresponding to the high-resolution
         grid (the file_paths input grid * s_enhance). This is used for time
@@ -413,27 +438,27 @@ def _get_data_3d(self):
         out : np.ndarray
             3D array of source data with shape (lats, lons, time).
         """
-        assert (
-            len(self.source_data.shape) == 2 and self.source_data.shape[1] > 1
-        )
         target_tmask = self.hr_time_index.isin(self.source_handler.time_index)
         source_tmask = self.source_handler.time_index.isin(self.hr_time_index)
         data = self.source_data[:, source_tmask]
         rows = pd.MultiIndex.from_product(
             [self.nn, range(data.shape[-1])], names=['gid_target', 'time']
         )
-        df = pd.DataFrame({self.feature: data.flatten()}, index=rows)
+        dists = np.repeat(self.dists[:, None], data.shape[-1], axis=1)
         n_target = np.prod(self.hr_shape[:-1])
+        df = pd.DataFrame(
+            {self.feature: data.flatten(), 'distance': dists.flatten()},
+            index=rows,
+        )
+        df = df[df.index.get_level_values(0) != n_target].sort_values(
+            'gid_target'
+        )
+        df = df.groupby(['gid_target', 'time']).apply(self._mean_fill)
         out = np.full(
             (n_target, len(self.hr_time_index)), np.nan, dtype=np.float32
         )
-        gids = df.index.get_level_values(0)
-        df = df[gids != n_target]
-        df = df.sort_values('gid_target')
-        df = df.groupby(['gid_target', 'time']).mean()
-        inds = gids.unique().values[gids.unique() != n_target][:, None]
-        vals = df[self.feature].values.reshape((-1, data.shape[-1]))
-        out[inds, target_tmask] = vals
+        inds = np.array(df.index.get_level_values(0).unique())[:, None]
+        out[inds, target_tmask] = df.values.reshape((-1, data.shape[-1]))
         out = out.reshape((*self.hr_shape[:-1], -1))
         return out

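The rewritten block drops the sentinel gid, then averages every source observation that maps to the same high-resolution grid id and time through a MultiIndex groupby (with `_mean_fill` reducing to a plain mean of the feature column). A toy version of that reduction with hypothetical values:

    import numpy as np
    import pandas as pd

    nn = np.array([2, 2, 5, 7])    # target gid per source point; 7 == sentinel
    data = np.array([[1.0, 2.0],   # (source points, time steps)
                     [3.0, 4.0],
                     [5.0, 6.0],
                     [9.0, 9.0]])
    rows = pd.MultiIndex.from_product(
        [nn, range(data.shape[-1])], names=['gid_target', 'time']
    )
    df = pd.DataFrame({'obs': data.flatten()}, index=rows)
    df = df[df.index.get_level_values(0) != 7].sort_values('gid_target')
    out = df.groupby(['gid_target', 'time'])['obs'].mean()
    print(out)
    # gid 2 averages the two source points mapped to it:
    # (2, 0) -> 2.0, (2, 1) -> 3.0; gid 5 keeps its single point: 5.0, 6.0
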
@@ -487,12 +512,6 @@ def _get_data_3d(self):
         """
         hr_data = super()._get_data_3d()
         gid_mask = self.nn != np.prod(self.hr_shape[:-1])
-        logger.info(
-            'Found {} of {} observations within {:4f} degrees of high-'
-            'resolution grid points.'.format(
-                gid_mask.sum(), len(gid_mask), self.distance_upper_bound
-            )
-        )
         cover_frac = (~np.isnan(hr_data)).sum() / hr_data.size
         if cover_frac == 0:
             msg = (
@@ -502,10 +521,14 @@
                 warn(msg)
                 logger.warning(msg)
             else:
-            msg = 'Observations cover {:.4e} of the high-res domain.'.format(
-                compute_if_dask(cover_frac)
+            msg = (
+                f'Found {gid_mask.sum()} of {len(gid_mask)} observations '
+                f'within {self.distance_upper_bound:4f} degrees of '
+                'high-resolution grid points. Observations cover '
+                f'{compute_if_dask(cover_frac):.4e} of the high-res domain.'
             )
             logger.info(msg)
+
         self.fill_nans = False  # override parent attribute to not fill nans
         return hr_data
