
Commit b48e287

Refactoring Objective tuple to a class with lazy properties
1 parent c7f1878

12 files changed: +381 −272 lines

neuralmonkey/trainers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1 +1,3 @@
 from .cross_entropy_trainer import CrossEntropyTrainer
+from .delayed_update_trainer import DelayedUpdateTrainer
+from .multitask_trainer import MultitaskTrainer

neuralmonkey/trainers/cross_entropy_trainer.py

Lines changed: 10 additions & 12 deletions
@@ -3,23 +3,21 @@
 import tensorflow as tf
 from typeguard import check_argument_types
 
-from neuralmonkey.trainers.generic_trainer import (
-    GenericTrainer, Objective, ObjectiveWeight)
+from neuralmonkey.logging import warn
+from neuralmonkey.trainers.generic_trainer import GenericTrainer
+from neuralmonkey.trainers.objective import (
+    Objective, CostObjective, ObjectiveWeight)
 
 
+# for compatibility reasons
 def xent_objective(decoder, weight=None) -> Objective:
     """Get XENT objective from decoder with cost."""
-    return Objective(
-        name="{} - cross-entropy".format(decoder.name),
-        decoder=decoder,
-        loss=decoder.cost,
-        gradients=None,
-        weight=weight,
-    )
-
-# pylint: disable=too-few-public-methods,too-many-arguments
+    warn("Using deprecated xent_objective function. Use the CostObjective "
+         "class directly.")
+    return CostObjective(decoder, weight)
 
 
+# pylint: disable=too-many-arguments
 class CrossEntropyTrainer(GenericTrainer):
 
     def __init__(self,
@@ -41,7 +39,7 @@ def __init__(self,
                 "decoder_weights (length {}) do not match decoders (length {})"
                 .format(len(decoder_weights), len(decoders)))
 
-        objectives = [xent_objective(dec, w)
+        objectives = [CostObjective(dec, w)
                       for dec, w in zip(decoders, decoder_weights)]
 
         GenericTrainer.__init__(
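
The old xent_objective helper is kept as a thin deprecation shim, so existing configurations keep working while new ones construct CostObjective directly. A minimal sketch of the calling difference (the decoder here is a stand-in for any model part exposing a cost tensor):

    from neuralmonkey.trainers.objective import CostObjective
    from neuralmonkey.trainers.cross_entropy_trainer import xent_objective

    # preferred: build the objective class directly, optionally weighted
    objective = CostObjective(decoder, weight=0.5)

    # still works, but logs a deprecation warning and delegates to the above
    objective = xent_objective(decoder, weight=0.5)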

neuralmonkey/trainers/generic_trainer.py

Lines changed: 14 additions & 30 deletions
@@ -1,42 +1,18 @@
-from typing import Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Sequence
 import re
 
 import tensorflow as tf
 from typeguard import check_argument_types
 
 from neuralmonkey.decorators import tensor
-from neuralmonkey.model.model_part import GenericModelPart
+from neuralmonkey.logging import warn
 from neuralmonkey.runners.base_runner import GraphExecutor, NextExecute
-
-# pylint: disable=invalid-name
-Gradients = List[Tuple[tf.Tensor, tf.Variable]]
-ObjectiveWeight = Union[tf.Tensor, float, None]
-# pylint: enable=invalid-name
+from neuralmonkey.trainers.objective import (
+    Objective, Gradients, ObjectiveWeight)
 
 BIAS_REGEX = re.compile(r"[Bb]ias")
 
 
-class Objective(NamedTuple(
-        "Objective",
-        [("name", str),
-         ("decoder", GenericModelPart),
-         ("loss", tf.Tensor),
-         ("gradients", Optional[Gradients]),
-         ("weight", ObjectiveWeight)])):
-    """The training objective.
-
-    Attributes:
-        name: The name for the objective. Used in TensorBoard.
-        decoder: The decoder which generates the value to optimize.
-        loss: The loss tensor fetched by the trainer.
-        gradients: Manually specified gradients. Useful for reinforcement
-            learning.
-        weight: The weight of this objective. The loss will be multiplied by
-            this so the gradients can be controlled in case of multiple
-            objectives.
-    """
-
-
 # pylint: disable=too-few-public-methods,too-many-locals,too-many-arguments
 class GenericTrainer(GraphExecutor):
 
@@ -78,7 +54,7 @@ def default_optimizer():
         return tf.train.AdamOptimizer(learning_rate=1e-4)
 
     def __init__(self,
-                 objectives: List[Objective],
+                 objectives: Sequence[Objective],
                  l1_weight: float = 0.0,
                  l2_weight: float = 0.0,
                  clip_norm: float = None,
@@ -110,6 +86,10 @@ def regularization_losses(self) -> Tuple[tf.Tensor, tf.Tensor]:
                          and not v.name.startswith("Inception")
                          and not v.name.startswith("resnet")]
 
+        if not regularizable:
+            warn("It seems that there are no trainable variables in the model")
+            return tf.zeros([]), tf.zeros([])
+
         with tf.name_scope("regularization"):
             l1_norm = sum(tf.reduce_sum(abs(v)) for v in regularizable)
             l2_norm = sum(tf.reduce_sum(v ** 2) for v in regularizable)
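
For reference, the two norms above reduce to a sum of absolute values and a sum of squares over the regularizable variables. A tiny standalone sketch, with an assumed constant standing in for a trainable variable:

    import tensorflow as tf

    v = tf.constant([[1.0, -2.0], [3.0, -4.0]])
    l1_norm = tf.reduce_sum(abs(v))  # |1| + |-2| + |3| + |-4| = 10
    l2_norm = tf.reduce_sum(v ** 2)  # 1 + 4 + 9 + 16 = 30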
@@ -120,11 +100,15 @@ def regularization_losses(self) -> Tuple[tf.Tensor, tf.Tensor]:
     @tensor
     def objective_values(self) -> List[tf.Tensor]:
         """Compute unweighted losses for fetching."""
+        # Note here we need to call the losses first, in case the model is
+        # being built. We need to compute the regularizers after that.
+        losses = [o.loss for o in self.objectives]
+
         # pylint: disable=unpacking-non-sequence
         l1_norm, l2_norm = self.regularization_losses
         # pylint: enable=unpacking-non-sequence
 
-        return [o.loss for o in self.objectives] + [l1_norm, l2_norm]
+        return losses + [l1_norm, l2_norm]
 
     @tensor
     def differentiable_loss_sum(self) -> tf.Tensor:
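
The reordering in objective_values matters because the loss properties are now lazy: touching them is what builds the model graph and creates its variables, which regularization_losses then finds via tf.trainable_variables(). A plain-Python sketch of the principle (the Model class here is hypothetical):

    class Model:
        def __init__(self):
            self.variables = []

        @property
        def loss(self):
            # building the loss lazily registers the model's variables
            if not self.variables:
                self.variables = ["W", "b"]
            return 0.0

    model = Model()
    loss = model.loss                      # side effect: variables now exist
    regularizable = list(model.variables)  # a scan after the build sees them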

neuralmonkey/trainers/multitask_trainer.py

Lines changed: 3 additions & 1 deletion
@@ -25,7 +25,9 @@ def __init__(self,
         self.trainers = trainers
         self.trainer_idx = 0
 
-        self.var_list = list(set.union(*[set(t.var_list) for t in trainers]))
+    @property
+    def var_list(self) -> List[tf.Variable]:
+        return list(set.union(*[set(t.var_list) for t in self.trainers]))
 
     def get_executable(
             self, compute_losses: bool = True, summaries: bool = True,
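
Turning var_list into a property defers the union until it is first needed, after the sub-trainers have built their variables. The merging idiom itself, shown standalone with hypothetical variable names:

    trainer_vars = [["w1", "w2"], ["w2", "w3"]]
    merged = list(set.union(*[set(vs) for vs in trainer_vars]))
    print(sorted(merged))  # ['w1', 'w2', 'w3']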

neuralmonkey/trainers/objective.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+from abc import abstractproperty
+from typing import TypeVar, Union, Tuple, List, Optional, Generic
+import tensorflow as tf
+from typeguard import check_argument_types
+
+from neuralmonkey.decorators import tensor
+from neuralmonkey.model.model_part import GenericModelPart
+
+# pylint: disable=invalid-name
+ObjectiveWeight = Union[tf.Tensor, float, None]
+Gradients = List[Tuple[tf.Tensor, tf.Variable]]
+MP = TypeVar("MP", bound=GenericModelPart)
+# pylint: enable=invalid-name
+
+
+class Objective(Generic[MP]):
+    """The training objective.
+
+    Attributes:
+        name: The name for the objective. Used in TensorBoard.
+        decoder: The decoder which generates the value to optimize.
+        loss: The loss tensor fetched by the trainer.
+        gradients: Manually specified gradients. Useful for reinforcement
+            learning.
+        weight: The weight of this objective. The loss will be multiplied by
+            this so the gradients can be controlled in case of multiple
+            objectives.
+    """
+
+    def __init__(self, name: str, decoder: MP) -> None:
+        self._name = name
+        self._decoder = decoder
+
+    @property
+    def decoder(self) -> MP:
+        return self._decoder
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @abstractproperty
+    def loss(self) -> tf.Tensor:
+        raise NotImplementedError()
+
+    @property
+    def gradients(self) -> Optional[Gradients]:
+        return None
+
+    @property
+    def weight(self) -> Optional[tf.Tensor]:
+        return None
+
+
+class CostObjective(Objective[GenericModelPart]):
+
+    def __init__(self, decoder: GenericModelPart,
+                 weight: ObjectiveWeight = None) -> None:
+        check_argument_types()
+
+        if not hasattr(decoder, "cost"):
+            raise TypeError("The decoder does not have a `cost` attribute")
+
+        name = "{} - cost".format(str(decoder))
+        Objective[GenericModelPart].__init__(self, name, decoder)
+        self._weight = weight
+
+    @tensor
+    def loss(self) -> tf.Tensor:
+        assert hasattr(self.decoder, "cost")
+        return getattr(self.decoder, "cost")
+
+    @tensor
+    def weight(self) -> Optional[tf.Tensor]:
+        if self._weight is None:
+            return None
+
+        if isinstance(self._weight, float):
+            return tf.constant(self._weight)
+
+        return self._weight
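
The loss and weight properties of CostObjective are decorated with @tensor (from neuralmonkey.decorators), which is what makes them lazy: the body runs once, on first access, and the result is cached. A rough plain-Python stand-in for illustration (an assumption, not the library's actual implementation):

    def tensor(func):
        attr = "_cached_" + func.__name__

        @property
        def lazy(self):
            # compute on first access, then serve the cached value
            if not hasattr(self, attr):
                setattr(self, attr, func(self))
            return getattr(self, attr)

        return lazy

    class Example:
        @tensor
        def value(self):
            print("computed once")
            return 42

    e = Example()
    e.value  # prints "computed once"
    e.value  # served from the cache; prints nothing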
