@@ -57,8 +57,19 @@ class BaseSolver(ABC, PretrainedSolver):
The optimizer to be used for training.
:type optimizer: `torch.nn.optim.Optimizer`, optional
:param criterion:
- A function that maps a PDE residual vector (torch tensor with shape (-1, 1)) to a scalar loss.
- :type criterion: callable, optional
+ The loss function used for training.
+
+ - If a str, must be present in the keys of `neurodiffeq.losses._losses`.
+ - If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
+ - If any other callable, it must map
+ A) a residual tensor (shape `(n_points, n_equations)`),
+ B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
+ C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_points, 1)`)
+ to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
+ so that backpropagation can be performed.
+
+ :type criterion:
+ str or `torch.nn.modules.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
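The callable form of ``criterion`` documented above has the same three-argument contract in every solver class touched by this diff. Below is a minimal sketch of a conforming custom criterion; the exponential weighting is purely illustrative and not part of neurodiffeq.

import torch

def weighted_residual_loss(residuals, funcs, coords):
    # residuals: tensor of shape (n_points, n_equations)
    # funcs:     tuple of length n_funcs, each element of shape (n_points, 1)
    # coords:    tuple of length n_coords, each element of shape (n_points, 1)
    t = coords[0]              # first coordinate, shape (n_points, 1)
    weights = torch.exp(-t)    # illustrative weighting; broadcasts over equations
    # The result is a scalar tensor that stays connected to the computational graph,
    # because `residuals` is computed from the networks' outputs.
    return (weights * residuals ** 2).mean()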
@@ -687,8 +698,19 @@ class SolverSpherical(BaseSolver):
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
- Function that maps a PDE residual tensor (of shape (-1, 1)) to a scalar loss.
- :type criterion: callable, optional
+ The loss function used for training.
+
+ - If a str, must be present in the keys of `neurodiffeq.losses._losses`.
+ - If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
+ - If any other callable, it must map
+ A) a residual tensor (shape `(n_points, n_equations)`),
+ B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
+ C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_points, 1)`)
+ to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
+ so that backpropagation can be performed.
+
+ :type criterion:
+ str or `torch.nn.modules.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
@@ -935,8 +957,19 @@ class Solver1D(BaseSolver):
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
- Function that maps a ODE residual tensor (of shape (-1, 1)) to a scalar loss.
- :type criterion: callable, optional
+ The loss function used for training.
+
+ - If a str, must be present in the keys of `neurodiffeq.losses._losses`.
+ - If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
+ - If any other callable, it must map
+ A) a residual tensor (shape `(n_points, n_equations)`),
+ B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
+ C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_points, 1)`)
+ to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
+ so that backpropagation can be performed.
+
+ :type criterion:
+ str or `torch.nn.modules.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
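For ``Solver1D`` (and analogously for the other solvers), the three accepted forms of ``criterion`` can be passed at construction time. The sketch below solves du/dt + u = 0 with u(0) = 1; it assumes, rather than asserts, that ``'l2'`` is one of the keys of ``neurodiffeq.losses._losses``, while the rest of the API (``Solver1D``, ``IVP``, ``diff``, ``fit``) is used as documented in the library.

import torch
from neurodiffeq import diff
from neurodiffeq.conditions import IVP
from neurodiffeq.solvers import Solver1D

# Exponential decay: du/dt + u = 0 with u(0) = 1
ode = lambda u, t: [diff(u, t) + u]
conditions = [IVP(t_0=0.0, u_0=1.0)]

# 1) A string key (assumed to exist in neurodiffeq.losses._losses)
solver_a = Solver1D(ode, conditions, t_min=0.0, t_max=2.0, criterion='l2')

# 2) A torch.nn.modules.loss._Loss instance, passed as-is
solver_b = Solver1D(ode, conditions, t_min=0.0, t_max=2.0, criterion=torch.nn.MSELoss())

# 3) Any other callable with the (residuals, funcs, coords) -> scalar signature
solver_c = Solver1D(
    ode, conditions, t_min=0.0, t_max=2.0,
    criterion=lambda residuals, funcs, coords: (residuals ** 2).mean(),
)
solver_c.fit(max_epochs=10)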
@@ -1108,8 +1141,19 @@ class BundleSolver1D(BaseSolver):
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
- Function that maps a ODE residual tensor (of shape (-1, 1)) to a scalar loss.
- :type criterion: callable, optional
+ The loss function used for training.
+
+ - If a str, must be present in the keys of `neurodiffeq.losses._losses`.
+ - If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
+ - If any other callable, it must map
+ A) a residual tensor (shape `(n_points, n_equations)`),
+ B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
+ C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_points, 1)`)
+ to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
+ so that backpropagation can be performed.
+
+ :type criterion:
+ str or `torch.nn.modules.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
@@ -1308,8 +1352,19 @@ class Solver2D(BaseSolver):
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.nn.optim.Optimizer``, optional
:param criterion:
- Function that maps a PDE residual tensor (of shape (-1, 1)) to a scalar loss.
- :type criterion: callable, optional
+ The loss function used for training.
+
+ - If a str, must be present in the keys of `neurodiffeq.losses._losses`.
+ - If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
+ - If any other callable, it must map
+ A) a residual tensor (shape `(n_points, n_equations)`),
+ B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
+ C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_points, 1)`)
+ to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
+ so that backpropagation can be performed.
+
+ :type criterion:
+ str or `torch.nn.modules.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.