|
14 | 14 |
|
15 | 15 |
|
16 | 16 | class Objective(Generic[MP]):
|
17 |
| - """The training objective. |
18 |
| -
|
19 |
| - Attributes: |
20 |
| - name: The name for the objective. Used in TensorBoard. |
21 |
| - decoder: The decoder which generates the value to optimize. |
22 |
| - loss: The loss tensor fetched by the trainer. |
23 |
| - gradients: Manually specified gradients. Useful for reinforcement |
24 |
| - learning. |
25 |
| - weight: The weight of this objective. The loss will be multiplied by |
26 |
| - this so the gradients can be controled in case of multiple |
27 |
| - objectives. |
28 |
| - """ |
| 17 | + """The training objective base class.""" |
29 | 18 |
|
30 | 19 | def __init__(self, name: str, decoder: MP) -> None:
|
| 20 | + """Construct the objective. |
| 21 | +
|
| 22 | + Arguments: |
| 23 | + name: The name for the objective. This will be used e.g. in |
| 24 | + TensorBoard. |
| 25 | + """ |
31 | 26 | self._name = name
|
32 | 27 | self._decoder = decoder
|
33 | 28 |
|
34 | 29 | @property
|
35 | 30 | def decoder(self) -> MP:
|
| 31 | + """Get the decoder used by the objective.""" |
36 | 32 | return self._decoder
|
37 | 33 |
|
38 | 34 | @property
|
39 | 35 | def name(self) -> str:
|
| 36 | + """Get the name of the objective.""" |
40 | 37 | return self._name
|
41 | 38 |
|
42 | 39 | @abstractproperty
|
43 | 40 | def loss(self) -> tf.Tensor:
|
| 41 | + """Return the loss tensor fetched by the trainer.""" |
44 | 42 | raise NotImplementedError()
|
45 | 43 |
|
46 | 44 | @property
|
47 | 45 | def gradients(self) -> Optional[Gradients]:
|
| 46 | + """Manually specified gradients - useful for reinforcement learning.""" |
48 | 47 | return None
|
49 | 48 |
|
50 | 49 | @property
|
51 | 50 | def weight(self) -> Optional[tf.Tensor]:
|
| 51 | + """Return the weight of this objective. |
| 52 | +
|
| 53 | + The loss will be multiplied by this so the gradients can be controlled |
| 54 | + in case of multiple objectives. |
| 55 | +
|
| 56 | + Returns: |
| 57 | + An optional tensor. If None, default weight of 1 is assumed. |
| 58 | + """ |
52 | 59 | return None
|
53 | 60 |
|
54 | 61 |
|
55 | 62 | class CostObjective(Objective[GenericModelPart]):
|
| 63 | + """Cost objective class. |
| 64 | +
|
| 65 | + This class represent objectives that are based directly on a `cost` |
| 66 | + attribute of any compatible model part. |
| 67 | + """ |
56 | 68 |
|
57 | 69 | def __init__(self, decoder: GenericModelPart,
|
58 | 70 | weight: ObjectiveWeight = None) -> None:
|
| 71 | + """Construct a new instance of the `CostObjective` class. |
| 72 | +
|
| 73 | + Arguments: |
| 74 | + decoder: A `GenericModelPart` instance that has a `cost` attribute. |
| 75 | + weight: The weight of the objective. |
| 76 | +
|
| 77 | + Raises: |
| 78 | + `TypeError` when the decoder argument does not have the `cost` |
| 79 | + attribute. |
| 80 | + """ |
59 | 81 | check_argument_types()
|
| 82 | + if "cost" not in dir(decoder): |
| 83 | + raise TypeError("The decoder does not have the 'cost' attribute") |
60 | 84 |
|
61 | 85 | name = "{} - cost".format(str(decoder))
|
| 86 | + |
62 | 87 | Objective[GenericModelPart].__init__(self, name, decoder)
|
63 | 88 | self._weight = weight
|
64 | 89 |
|
65 | 90 | @tensor
|
66 | 91 | def loss(self) -> tf.Tensor:
|
67 |
| - if "cost" not in dir(self.decoder): |
68 |
| - raise TypeError("The decoder does not have the 'cost' attribute") |
69 |
| - |
70 | 92 | return getattr(self.decoder, "cost")
|
71 | 93 |
|
72 | 94 | @tensor
|
|
0 commit comments