Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from sup3r.preprocessing.data_handlers import ExoData
from sup3r.preprocessing.utilities import numpy_if_tensor
from sup3r.utilities import VERSION_RECORD
from sup3r.utilities.utilities import camel_to_underscore, safe_cast
from sup3r.utilities.utilities import Timer, camel_to_underscore, safe_cast

from .utilities import SUP3R_LAYERS, SUP3R_OBS_LAYERS, TensorboardMixIn

Expand All @@ -40,11 +40,12 @@ def __init__(self):
super().__init__()
self.gpu_list = tf.config.list_physical_devices('GPU')
self.default_device = '/cpu:0' if len(self.gpu_list) == 0 else '/gpu:0'
self._version_record = VERSION_RECORD
self.timer = Timer()
self.name = None
self._meta = None
self.loss_name = None
self.loss_fun = None
self._version_record = VERSION_RECORD
self._meta = None
self._history = None
self._optimizer = None
self._gen = None
Expand Down
28 changes: 15 additions & 13 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def train(
adaptive_update_bounds=(0.9, 0.99),
adaptive_update_fraction=0.0,
multi_gpu=False,
tensorboard_log=True,
tensorboard_log=False,
tensorboard_profile=False,
):
"""Train the GAN model on real low res data and real high res data
Expand Down Expand Up @@ -1062,8 +1062,9 @@ def _post_batch(self, ib, b_loss_details, n_batches, previous_means):
prefix='train_',
)

self.dict_to_tensorboard(b_loss_details)
self.dict_to_tensorboard(self.timer.log)
if self._tb_writer is not None:
self.dict_to_tensorboard(b_loss_details)
self.dict_to_tensorboard(self.timer.log)

trained_gen = bool(self._train_record['gen_train_frac'].values[-1])
trained_disc = bool(self._train_record['disc_train_frac'].values[-1])
Expand Down Expand Up @@ -1154,14 +1155,15 @@ def _train_epoch(
tf.summary.trace_on(graph=True, profiler=True)

for ib, batch in enumerate(batch_handler):
start = time.time()

b_loss_details = {}
loss_disc = loss_means['train_loss_disc']
disc_too_good = loss_disc <= disc_th_low
disc_too_bad = (loss_disc > disc_th_high) and train_disc
gen_too_good = disc_too_bad

start = time.time()
b_loss_details = self._train_batch(
b_loss_details = self.timer(self._train_batch, log=True)(
batch,
train_gen,
only_gen,
Expand All @@ -1172,14 +1174,14 @@ def _train_epoch(
weight_gen_advers,
multi_gpu,
)
elapsed = time.time() - start
logger.info('Finished batch in {:.4f} seconds'.format(elapsed))

loss_means = self._post_batch(
ib,
b_loss_details,
len(batch_handler),
loss_means,

loss_means = self.timer(self._post_batch, log=True)(
ib, b_loss_details, len(batch_handler), loss_means
)

logger.info(
f'Finished batch step {ib + 1} / {len(batch_handler)} in '
f'{time.time() - start:.4f} seconds'
)

self.total_batches += len(batch_handler)
Expand Down
13 changes: 6 additions & 7 deletions sup3r/models/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,12 @@ def dict_to_tensorboard(self, entry):
entry: dict
Dictionary of values to write to tensorboard log file
"""
if self._tb_writer is not None:
with self._tb_writer.as_default():
for name, value in entry.items():
if isinstance(value, str):
tf.summary.text(name, value, self.total_batches)
else:
tf.summary.scalar(name, value, self.total_batches)
with self._tb_writer.as_default():
for name, value in entry.items():
if isinstance(value, str):
tf.summary.text(name, value, self.total_batches)
else:
tf.summary.scalar(name, value, self.total_batches)

def profile_to_tensorboard(self, name):
"""Write profile data to tensorboard log file.
Expand Down
4 changes: 4 additions & 0 deletions sup3r/postprocessing/writers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,10 @@ def get_lat_lon(cls, low_res_lat_lon, shape):
(spatial_1, spatial_2, 2)
Last dimension has ordering (lat, lon)
"""
assert (
low_res_lat_lon.shape[0] > 1 and low_res_lat_lon.shape[1] > 1
), 'low res lat/lon must have at least 2 rows and 2 columns'

logger.debug('Getting high resolution lat / lon grid')

# ensure lons are between -180 and 180
Expand Down
28 changes: 10 additions & 18 deletions sup3r/preprocessing/batch_queues/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,15 @@ def get_batch(self) -> DsetTuple:
"""Get batch from queue or directly from a ``Sampler`` through
``sample_batch``."""
if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0:
return self.sample_batch()
return self.queue.dequeue()
samples = self.sample_batch()
else:
samples = self.queue.dequeue()
if self.sample_shape[2] == 1:
if isinstance(samples, (list, tuple)):
samples = tuple(s[..., 0, :] for s in samples)
else:
samples = samples[..., 0, :]
return self.post_proc(samples)

@property
def running(self):
Expand Down Expand Up @@ -275,7 +282,6 @@ def enqueue_batches(self) -> None:
log_time = time.time()
while self.running:
needed = max(self.queue_cap - self.queue_len, 0)
needed = min(self.max_workers, needed)
if needed > 0:
batches = self.sample_batches(n_batches=needed)
if needed > 1 and self.max_workers > 1:
Expand All @@ -300,22 +306,8 @@ def __next__(self) -> DsetTuple:
Batch object with batch.low_res and batch.high_res attributes
"""
if self._batch_count < self.n_batches:
self.timer.start()
samples = self.get_batch()
if self.sample_shape[2] == 1:
if isinstance(samples, (list, tuple)):
samples = tuple(s[..., 0, :] for s in samples)
else:
samples = samples[..., 0, :]
batch = self.post_proc(samples)
self.timer.stop()
batch = self.timer(self.get_batch, log=self.verbose)()
self._batch_count += 1
if self.verbose:
logger.debug(
'Batch step %s finished in %s.',
self._batch_count,
self.timer.elapsed_str,
)
else:
raise StopIteration
return batch
Expand Down
105 changes: 81 additions & 24 deletions sup3r/preprocessing/rasterizers/exo.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ class will output a topography raster corresponding to the file_paths
fill_nans : bool
Whether to fill nans in the output data. This should probably be True
for all cases except for sparse observation data.
agg_method : str
Method to use for aggregating source data to the target pixels. This
can be 'mean', 'idw' (inverse distance weighted average), or 'nn'
(nearest neighbor). The default is 'mean'. This is only used for 3D
data (time dependent) and will be ignored for 2D data (time
independent).
scale_factor : float
Scale factor to apply to the raw data from the source_files. This is
useful for scaling observation data which might systematically under or
Expand All @@ -125,6 +131,7 @@ class will output a topography raster corresponding to the file_paths
chunks: Optional[Union[str, dict]] = 'auto'
distance_upper_bound: Optional[int] = None
fill_nans: bool = True
agg_method: str = 'mean'
scale_factor: float = 1.0
max_workers: int = 1
verbose: bool = False
Expand All @@ -141,6 +148,8 @@ def __post_init__(self):
self._hr_lat_lon = None
self._source_lat_lon = None
self._hr_time_index = None
self._nn = None
self._dists = None
self.input_handler_kwargs = self.input_handler_kwargs or {}
self.source_handler_kwargs = self.source_handler_kwargs or {}
InputHandler = get_input_handler_class(self.input_handler_name)
Expand Down Expand Up @@ -278,6 +287,11 @@ def get_distance_upper_bound(self):
"""Maximum distance (float) to map high-resolution data from
source_files to the low-resolution file_paths input."""
if self.distance_upper_bound is None:
assert self.hr_lat_lon.shape[0] > 1, (
'hr_lat_lon must have at least 2 lat points to calculate '
'distance upper bound. Either expand the grid or provide a '
'distance_upper_bound explicitly.'
)
diff = da.diff(self.hr_lat_lon, axis=0)
diff = da.abs(da.median(diff, axis=0)).max()
self.distance_upper_bound = np.asarray(diff)
Expand All @@ -296,15 +310,29 @@ def tree(self):
self._tree = KDTree(self.hr_lat_lon.reshape((-1, 2)))
return self._tree

def query_tree(self, lat_lon):
"""Query the KDTree for the nearest neighbor indices and distances
for the given lat_lon points."""
return self.tree.query(
lat_lon,
distance_upper_bound=self.get_distance_upper_bound(),
)

@property
def nn(self):
"""Get the nearest neighbor indices. This uses a single neighbor by
default"""
_, nn = self.tree.query(
self.source_lat_lon,
distance_upper_bound=self.get_distance_upper_bound(),
)
return nn
if self._nn is None:
self._dists, self._nn = self.query_tree(self.source_lat_lon)
return self._nn

@property
def dists(self):
"""Get the nearest neighbor indices. This uses a single neighbor by
default"""
if self._dists is None:
self._dists, self._nn = self.query_tree(self.source_lat_lon)
return self._dists

@property
def data(self):
Expand Down Expand Up @@ -371,6 +399,25 @@ def get_data(self):
data_vars = {self.feature: data_vars}
return Sup3rX(xr.Dataset(coords=self.coords, data_vars=data_vars))

def _idw_fill(self, x):
"""Compute weighted average for a group of data."""
valid_mask = ~np.isnan(x[self.feature].values)
if valid_mask.sum() == 0:
return np.nan
weights = 1 / np.maximum(x['distance'], 1e-6)
return np.average(
x[self.feature][valid_mask],
weights=weights[valid_mask] / weights[valid_mask].sum(),
)

def _mean_fill(self, x):
"""Compute standard average for a group of data."""
return x[self.feature].mean()

def _nn_fill(self, x):
"""Select value with min distance."""
return x[self.feature].iloc[np.argmin(x['distance'])]

def _get_data_2d(self):
"""Get a raster of source values corresponding to the high-resolution
grid (the file_paths input grid * s_enhance). This is used for time
Expand Down Expand Up @@ -413,27 +460,39 @@ def _get_data_3d(self):
out : np.ndarray
3D array of source data with shape (lats, lons, time).
"""
assert (
len(self.source_data.shape) == 2 and self.source_data.shape[1] > 1
)
if self.agg_method == 'idw':
func = self._idw_fill
elif self.agg_method == 'nn':
func = self._nn_fill
elif self.agg_method == 'mean':
func = self._mean_fill
else:
raise ValueError(
f'Unknown aggregation method: {self.agg_method}. '
'Must be one of "idw", "nn", or "mean".'
)
logger.info('Using {} aggregation method'.format(self.agg_method))
target_tmask = self.hr_time_index.isin(self.source_handler.time_index)
source_tmask = self.source_handler.time_index.isin(self.hr_time_index)
data = self.source_data[:, source_tmask]
rows = pd.MultiIndex.from_product(
[self.nn, range(data.shape[-1])], names=['gid_target', 'time']
)
df = pd.DataFrame({self.feature: data.flatten()}, index=rows)
dists = np.repeat(self.dists[:, None], data.shape[-1], axis=1)
n_target = np.prod(self.hr_shape[:-1])
df = pd.DataFrame(
{self.feature: data.flatten(), 'distance': dists.flatten()},
index=rows,
)
df = df[df.index.get_level_values(0) != n_target].sort_values(
'gid_target'
)
df = df.groupby(['gid_target', 'time']).apply(func)
out = np.full(
(n_target, len(self.hr_time_index)), np.nan, dtype=np.float32
)
gids = df.index.get_level_values(0)
df = df[gids != n_target]
df = df.sort_values('gid_target')
df = df.groupby(['gid_target', 'time']).mean()
inds = gids.unique().values[gids.unique() != n_target][:, None]
vals = df[self.feature].values.reshape((-1, data.shape[-1]))
out[inds, target_tmask] = vals
inds = np.array(df.index.get_level_values(0).unique())[:, None]
out[inds, target_tmask] = df.values.reshape((-1, data.shape[-1]))
out = out.reshape((*self.hr_shape[:-1], -1))
return out

Expand Down Expand Up @@ -487,12 +546,6 @@ def _get_data_3d(self):
"""
hr_data = super()._get_data_3d()
gid_mask = self.nn != np.prod(self.hr_shape[:-1])
logger.info(
'Found {} of {} observations within {:4f} degrees of high-'
'resolution grid points.'.format(
gid_mask.sum(), len(gid_mask), self.distance_upper_bound
)
)
cover_frac = (~np.isnan(hr_data)).sum() / hr_data.size
if cover_frac == 0:
msg = (
Expand All @@ -502,10 +555,14 @@ def _get_data_3d(self):
warn(msg)
logger.warning(msg)
else:
msg = 'Observations cover {:.4e} of the high-res domain.'.format(
compute_if_dask(cover_frac)
msg = (
f'Found {gid_mask.sum()} of {len(gid_mask)} observations '
f'within {self.distance_upper_bound:4f} degrees of '
'high-resolution grid points. Observations cover '
f'{compute_if_dask(cover_frac):.4e} of the high-res domain.'
)
logger.info(msg)

self.fill_nans = False # override parent attribute to not fill nans
return hr_data

Expand Down
Loading