Skip to content

Commit 0f85aeb

Browse files
authored
Merge pull request #184 from NeuroDiffGym/v0.6.1
Hot fix: solve a fatal compatibility issue with torch v1.13
2 parents 4eac45c + 36966e8 commit 0f85aeb

11 files changed

+121
-62
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,6 @@ docs/_build/
122122
_test/
123123

124124
.DS_Store
125+
126+
# Tensorboard
127+
runs/

neurodiffeq/_version_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def warn_deprecate_class(new_class):
99
:return: a function that, when called, acts as if it is a class constructor
1010
:rtype: callable
1111
"""
12+
1213
@functools.wraps(new_class)
1314
def old_class_getter(*args, **kwargs):
1415
warnings.warn(f"This class name is deprecated, use {new_class} instead", FutureWarning)
@@ -26,12 +27,15 @@ def deprecated_alias(**aliases):
2627
:return: A decorated function that can receive either `old_name` or `new_name` as input
2728
:rtype: function
2829
"""
30+
2931
def deco(f):
3032
@functools.wraps(f) # preserves signature and docstring
3133
def wrapper(*args, **kwargs):
3234
_rename_kwargs(f.__name__, kwargs, aliases)
3335
return f(*args, **kwargs)
36+
3437
return wrapper
38+
3539
return deco
3640

3741

@@ -40,5 +44,5 @@ def _rename_kwargs(func_name, kwargs, aliases):
4044
if alias in kwargs:
4145
if new in kwargs:
4246
raise KeyError(f'{func_name} received both `{alias}` (deprecated) and `{new}` (recommended)')
43-
warnings.warn(f'The argument `{alias}` is deprecated; use `{new}` instead for {func_name}.', FutureWarning)
44-
kwargs[new] = kwargs.pop(alias)
47+
warnings.warn(f'The argument `{alias}` is deprecated for {func_name}; use `{new}` instead.', FutureWarning)
48+
kwargs[new] = kwargs.pop(alias)

neurodiffeq/callbacks.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -262,19 +262,19 @@ def __call__(self, solver):
262262
)
263263

264264

265-
class SetCriterion(ActionCallback):
265+
class SetLossFn(ActionCallback):
266266
r"""A callback that sets the ``criterion`` (a.k.a. loss function) of the solver.
267267
Best used together with a condition callback.
268268
269-
:param criterion:
269+
:param loss_fn:
270270
The loss function to be set for the solver. It can be
271271
272272
- An instance of ``torch.nn.modules.loss._Loss``
273273
which computes loss of the PDE/ODE residuals against a zero tensor.
274274
- A callable object which maps residuals, function values, and input coordinates to a scalar loss; or
275275
- A str which is present in ``neurodiffeq.losses._losses.keys()``.
276276
277-
:type criterion: ``torch.nn.modules.loss._Loss`` or callable or str.
277+
:type loss_fn: ``torch.nn.modules.loss._Loss`` or callable or str.
278278
:param reset:
279279
If True, the criterion will be reset every time the callback is called.
280280
Otherwise, the criterion will only be set once.
@@ -284,19 +284,22 @@ class SetCriterion(ActionCallback):
284284
:type logger: str or ``logging.Logger``
285285
"""
286286

287-
def __init__(self, criterion, reset=False, logger=None):
288-
super(SetCriterion, self).__init__(logger=logger)
289-
self.criterion = criterion
287+
@deprecated_alias(criterion='loss_fn')
288+
def __init__(self, loss_fn, reset=False, logger=None):
289+
super(SetLossFn, self).__init__(logger=logger)
290+
self.loss_fn = loss_fn
290291
self.reset = reset
291292
self.called = False
292293

293294
def __call__(self, solver):
294295
if self.reset or (not self.called):
295296
self.called = True
296297
# noinspection PyProtectedMember
297-
solver._set_criterion(self.criterion)
298+
solver._set_loss_fn(self.loss_fn)
298299

299300

301+
SetCriterion = warn_deprecate_class(SetLossFn)
302+
300303
class SetOptimizer(ActionCallback):
301304
r"""A callback that sets the optimizer of the solver. Best used together with a condition callback.
302305

neurodiffeq/generators.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -671,10 +671,8 @@ def _internal_vars(self) -> dict:
671671
class PredefinedGenerator(BaseGenerator):
672672
"""A generator for generating points that are fixed and predefined.
673673
674-
:param xs: The x-dimension of the trianing points
675-
:type xs: `torch.Tensor`
676-
:param ys: The y-dimension of the training points
677-
:type ys: `torch.Tensor`
674+
:param xs: training points that will be returned
675+
:type xs: Tuple[`torch.Tensor`]
678676
"""
679677

680678
def __init__(self, *xs):

neurodiffeq/ode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ class CustomSolver1D(Solver1D):
296296
train_generator=train_generator,
297297
valid_generator=valid_generator,
298298
optimizer=optimizer,
299-
criterion=criterion,
299+
loss_fn=criterion,
300300
n_batches_train=n_batches_train,
301301
n_batches_valid=n_batches_valid,
302302
metrics=metrics,

neurodiffeq/pde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ class CustomSolver2D(Solver2D):
321321
train_generator=train_generator,
322322
valid_generator=valid_generator,
323323
optimizer=optimizer,
324-
criterion=criterion,
324+
loss_fn=criterion,
325325
n_batches_train=n_batches_train,
326326
n_batches_valid=n_batches_valid,
327327
metrics=metrics,

neurodiffeq/pde_spherical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def enforcer(net, cond, points):
263263
valid_generator=valid_generator,
264264
analytic_solutions=analytic_solutions,
265265
optimizer=optimizer,
266-
criterion=criterion,
266+
loss_fn=criterion,
267267
n_batches_train=1,
268268
n_batches_valid=1,
269269
enforcer=enforcer,

neurodiffeq/solvers.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626

2727

2828
def _requires_closure(optimizer):
29-
return inspect.signature(optimizer.step).parameters.get('closure').default == inspect._empty
29+
# starting from torch v1.13, simple optimizers no longer have a `closure` argument
30+
closure_param = inspect.signature(optimizer.step).parameters.get('closure')
31+
return closure_param and closure_param.default == inspect._empty
3032

3133

3234
class BaseSolver(ABC, PretrainedSolver):
@@ -60,7 +62,7 @@ class BaseSolver(ABC, PretrainedSolver):
6062
:param optimizer:
6163
The optimizer to be used for training.
6264
:type optimizer: `torch.nn.optim.Optimizer`, optional
63-
:param criterion:
65+
:param loss_fn:
6466
The loss function used for training.
6567
6668
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
@@ -72,7 +74,7 @@ class BaseSolver(ABC, PretrainedSolver):
7274
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
7375
so that backpropagation can be performed.
7476
75-
:type criterion:
77+
:type loss_fn:
7678
str or `torch.nn.moduesl.loss._Loss` or callable
7779
:param n_batches_train:
7880
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
@@ -107,9 +109,10 @@ class BaseSolver(ABC, PretrainedSolver):
107109
:type shuffle: bool
108110
"""
109111

112+
@deprecated_alias(criterion='loss_fn')
110113
def __init__(self, diff_eqs, conditions,
111114
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None,
112-
optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4,
115+
optimizer=None, loss_fn=None, n_batches_train=1, n_batches_valid=4,
113116
metrics=None, n_input_units=None, n_output_units=None,
114117
# deprecated arguments are listed below
115118
shuffle=None, batch_size=None):
@@ -176,7 +179,7 @@ def analytic_mse(*args):
176179
self.metrics_history.update({'valid__' + name: [] for name in self.metrics_fn})
177180

178181
self.optimizer = optimizer if optimizer else Adam(set(chain.from_iterable(n.parameters() for n in self.nets)))
179-
self._set_criterion(criterion)
182+
self._set_loss_fn(loss_fn)
180183

181184
def make_pair_dict(train=None, valid=None):
182185
return {'train': train, 'valid': valid}
@@ -203,15 +206,15 @@ def make_pair_dict(train=None, valid=None):
203206
# the _phase variable is registered for callback functions to access
204207
self._phase = None
205208

206-
def _set_criterion(self, criterion):
209+
def _set_loss_fn(self, criterion):
207210
if criterion is None:
208-
self.criterion = lambda r, f, x: (r ** 2).mean()
211+
self.loss_fn = lambda r, f, x: (r ** 2).mean()
209212
elif isinstance(criterion, nn.modules.loss._Loss):
210-
self.criterion = lambda r, f, x: criterion(r, torch.zeros_like(r))
213+
self.loss_fn = lambda r, f, x: criterion(r, torch.zeros_like(r))
211214
elif isinstance(criterion, str):
212-
self.criterion = _losses[criterion.lower()]
215+
self.loss_fn = _losses[criterion.lower()]
213216
elif callable(criterion):
214-
self.criterion = criterion
217+
self.loss_fn = criterion
215218
else:
216219
raise TypeError(f"Unknown type of criterion {type(criterion)}")
217220

@@ -236,6 +239,24 @@ def _batch_examples(self):
236239
)
237240
return self._batch
238241

242+
@property
243+
def criterion(self):
244+
warnings.warn(
245+
f'`{self.__class__.__name__}`.criterion is a deprecated alias for `{self.__class__.__name__}.loss_fn`.'
246+
f'The alias is only meant to be accessed by certain functions in `neurodiffeq.solver_utils` '
247+
f'until proper fixes are made; by which time this alias will be removed.'
248+
)
249+
return self.loss_fn
250+
251+
@criterion.setter
252+
def criterion(self, loss_fn):
253+
warnings.warn(
254+
f'`{self.__class__.__name__}`.criterion is a deprecated alias for `{self.__class__.__name__}.loss_fn`.'
255+
f'The alias is only meant to be accessed by certain functions in `neurodiffeq.solver_utils` '
256+
f'until proper fixes are made; by which time this alias will be removed.'
257+
)
258+
self.loss_fn = loss_fn
259+
239260
def compute_func_val(self, net, cond, *coordinates):
240261
r"""Compute the function value evaluated on the points specified by ``coordinates``.
241262
@@ -352,7 +373,7 @@ def closure(zero_grad=True):
352373
residuals = self.diff_eqs(*funcs, *batch)
353374
residuals = torch.cat(residuals, dim=1)
354375
try:
355-
loss = self.criterion(residuals, funcs, batch) + self.additional_loss(residuals, funcs, batch)
376+
loss = self.loss_fn(residuals, funcs, batch) + self.additional_loss(residuals, funcs, batch)
356377
except TypeError as e:
357378
warnings.warn(
358379
"You might need to update your code. "
@@ -507,7 +528,8 @@ def _get_internal_variables(self):
507528
"metrics": self.metrics_fn,
508529
"n_batches": self.n_batches,
509530
"best_nets": self.best_nets,
510-
"criterion": self.criterion,
531+
"criterion": self.loss_fn,
532+
"loss_fn": self.loss_fn,
511533
"conditions": self.conditions,
512534
"global_epoch": self.global_epoch,
513535
"lowest_loss": self.lowest_loss,
@@ -766,7 +788,7 @@ class SolverSpherical(BaseSolver):
766788
Optimizer to be used for training.
767789
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
768790
:type optimizer: ``torch.nn.optim.Optimizer``, optional
769-
:param criterion:
791+
:param loss_fn:
770792
The loss function used for training.
771793
772794
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
@@ -778,7 +800,7 @@ class SolverSpherical(BaseSolver):
778800
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
779801
so that backpropagation can be performed.
780802
781-
:type criterion:
803+
:type loss_fn:
782804
str or `torch.nn.moduesl.loss._Loss` or callable
783805
:param n_batches_train:
784806
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
@@ -820,7 +842,7 @@ class SolverSpherical(BaseSolver):
820842

821843
def __init__(self, pde_system, conditions, r_min=None, r_max=None,
822844
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None,
823-
optimizer=None, criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, enforcer=None,
845+
optimizer=None, loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, enforcer=None,
824846
n_output_units=1,
825847
# deprecated arguments are listed below
826848
shuffle=None, batch_size=None):
@@ -848,7 +870,7 @@ def __init__(self, pde_system, conditions, r_min=None, r_max=None,
848870
valid_generator=valid_generator,
849871
analytic_solutions=analytic_solutions,
850872
optimizer=optimizer,
851-
criterion=criterion,
873+
loss_fn=loss_fn,
852874
n_batches_train=n_batches_train,
853875
n_batches_valid=n_batches_valid,
854876
metrics=metrics,
@@ -1025,7 +1047,7 @@ class Solver1D(BaseSolver):
10251047
Optimizer to be used for training.
10261048
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
10271049
:type optimizer: ``torch.nn.optim.Optimizer``, optional
1028-
:param criterion:
1050+
:param loss_fn:
10291051
The loss function used for training.
10301052
10311053
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
@@ -1037,7 +1059,7 @@ class Solver1D(BaseSolver):
10371059
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
10381060
so that backpropagation can be performed.
10391061
1040-
:type criterion:
1062+
:type loss_fn:
10411063
str or `torch.nn.moduesl.loss._Loss` or callable
10421064
:param n_batches_train:
10431065
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
@@ -1073,7 +1095,7 @@ class Solver1D(BaseSolver):
10731095

10741096
def __init__(self, ode_system, conditions, t_min=None, t_max=None,
10751097
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None,
1076-
criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
1098+
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
10771099
# deprecated arguments are listed below
10781100
batch_size=None, shuffle=None):
10791101

@@ -1098,7 +1120,7 @@ def __init__(self, ode_system, conditions, t_min=None, t_max=None,
10981120
valid_generator=valid_generator,
10991121
analytic_solutions=analytic_solutions,
11001122
optimizer=optimizer,
1101-
criterion=criterion,
1123+
loss_fn=loss_fn,
11021124
n_batches_train=n_batches_train,
11031125
n_batches_valid=n_batches_valid,
11041126
metrics=metrics,
@@ -1209,7 +1231,7 @@ class BundleSolver1D(BaseSolver):
12091231
Optimizer to be used for training.
12101232
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
12111233
:type optimizer: ``torch.nn.optim.Optimizer``, optional
1212-
:param criterion:
1234+
:param loss_fn:
12131235
The loss function used for training.
12141236
12151237
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
@@ -1221,7 +1243,7 @@ class BundleSolver1D(BaseSolver):
12211243
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
12221244
so that backpropagation can be performed.
12231245
1224-
:type criterion:
1246+
:type loss_fn:
12251247
str or `torch.nn.moduesl.loss._Loss` or callable
12261248
:param n_batches_train:
12271249
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
@@ -1258,7 +1280,7 @@ class BundleSolver1D(BaseSolver):
12581280
def __init__(self, ode_system, conditions, t_min, t_max,
12591281
theta_min=None, theta_max=None,
12601282
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None,
1261-
criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
1283+
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
12621284
# deprecated arguments are listed below
12631285
batch_size=None, shuffle=None):
12641286

@@ -1319,7 +1341,7 @@ def non_var_filter(*variables):
13191341
valid_generator=valid_generator,
13201342
analytic_solutions=analytic_solutions,
13211343
optimizer=optimizer,
1322-
criterion=criterion,
1344+
loss_fn=loss_fn,
13231345
n_batches_train=n_batches_train,
13241346
n_batches_valid=n_batches_valid,
13251347
metrics=metrics,
@@ -1420,7 +1442,7 @@ class Solver2D(BaseSolver):
14201442
Optimizer to be used for training.
14211443
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
14221444
:type optimizer: ``torch.nn.optim.Optimizer``, optional
1423-
:param criterion:
1445+
:param loss_fn:
14241446
The loss function used for training.
14251447
14261448
- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
@@ -1432,7 +1454,7 @@ class Solver2D(BaseSolver):
14321454
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
14331455
so that backpropagation can be performed.
14341456
1435-
:type criterion:
1457+
:type loss_fn:
14361458
str or `torch.nn.moduesl.loss._Loss` or callable
14371459
:param n_batches_train:
14381460
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
@@ -1468,7 +1490,7 @@ class Solver2D(BaseSolver):
14681490

14691491
def __init__(self, pde_system, conditions, xy_min=None, xy_max=None,
14701492
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None,
1471-
criterion=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
1493+
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
14721494
# deprecated arguments are listed below
14731495
batch_size=None, shuffle=None):
14741496

@@ -1493,7 +1515,7 @@ def __init__(self, pde_system, conditions, xy_min=None, xy_max=None,
14931515
valid_generator=valid_generator,
14941516
analytic_solutions=analytic_solutions,
14951517
optimizer=optimizer,
1496-
criterion=criterion,
1518+
loss_fn=loss_fn,
14971519
n_batches_train=n_batches_train,
14981520
n_batches_valid=n_batches_valid,
14991521
metrics=metrics,

0 commit comments

Comments
 (0)