
Commit 8515d6c

Merge pull request #777 from ufal/tf-data-1
Towards TF dataset, part I
2 parents b57d4d3 + 4f0f44f commit 8515d6c

34 files changed: 1,308 additions and 1,270 deletions

neuralmonkey/attention/combination.py

Lines changed: 7 additions & 5 deletions
@@ -23,6 +23,7 @@
     get_attention_states, get_attention_mask, Attendable)
 from neuralmonkey.attention.namedtuples import HierarchicalLoopState
 from neuralmonkey.checking import assert_shape
+from neuralmonkey.decorators import tensor
 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.model.parameterized import InitializerSpecs
 from neuralmonkey.tf_utils import get_variable
@@ -49,11 +50,6 @@ def __init__(self,
         self._use_sentinels = use_sentinels
 
         self.att_scope_name = "attention_{}".format(name)
-
-        with self.use_scope():
-            self.attn_v = get_variable(
-                "attn_v", [1, 1, self.attention_state_size],
-                initializer=tf.random_normal_initializer(stddev=0.001))
     # pylint: enable=unused-argument,too-many-arguments
 
     def attention(self,
@@ -64,6 +60,12 @@ def attention(self,
         """Get context vector for given decoder state."""
        raise NotImplementedError("Abstract method")
 
+    @tensor
+    def attn_v(self) -> tf.Tensor:
+        return get_variable(
+            "attn_v", [1, 1, self.attention_state_size],
+            initializer=tf.random_normal_initializer(stddev=0.001))
+
     @property
     def attn_size(self):
         return self.attention_state_size
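
Note: the change above replaces the eagerly constructed attn_v variable with a lazily built one. The @tensor decorator turns the method into a cached property, so the variable is only added to the graph when another tensor first needs it. A minimal sketch of such a decorator, for illustration only (the real one lives in neuralmonkey.decorators and may differ, e.g. by entering the model part's variable scope):

def tensor(func):
    """Cache the wrapped method's result; build the tensor on first access."""
    attr = "_lazy_" + func.__name__

    @property
    def wrapper(self):
        if not hasattr(self, attr):
            setattr(self, attr, func(self))  # build the graph node on demand
        return getattr(self, attr)
    return wrapper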

neuralmonkey/attention/feed_forward.py

Lines changed: 14 additions & 5 deletions
@@ -13,7 +13,7 @@
     BaseAttention, AttentionLoopState, empty_attention_loop_state,
     get_attention_states, get_attention_mask, Attendable)
 from neuralmonkey.decorators import tensor
-from neuralmonkey.logging import log
+from neuralmonkey.logging import debug
 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.model.parameterized import InitializerSpecs
 from neuralmonkey.nn.utils import dropout
@@ -42,10 +42,6 @@ def __init__(self,
 
         self._variable_scope.set_initializer(
             tf.random_normal_initializer(stddev=0.001))
-
-        # TODO blessing
-        log("Hidden features: {}".format(self.hidden_features))
-        log("Attention mask: {}".format(self.attention_mask))
     # pylint: enable=too-many-arguments
 
     @tensor
@@ -170,6 +166,19 @@ def attention(self,
         return context, next_loop_state
 
     def initial_loop_state(self) -> AttentionLoopState:
+
+        # Here we need to make sure that the hidden_features and attention_mask
+        # are pre-computed. If this is used in combination with a decoder which
+        # has train and runtime while loops, these tensors need to be created
+        # outside of any of those loops in order to be available to both.
+
+        # Note that we are not breaking lazy loading here because this method
+        # is called from a lazy tensor.
+
+        debug("Pre-computing attention tensors", "bless")
+        debug("Hidden features: {}".format(self.hidden_features), "bless")
+        debug("Hidden mask: {}".format(self.attention_mask), "bless")
+
         return empty_attention_loop_state(
             self.batch_size,
             tf.shape(self.attention_states)[1],
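
Note: the new comment leans on the lazy-tensor behaviour: touching self.hidden_features and self.attention_mask inside initial_loop_state forces both tensors to be created in the outer graph before the decoder builds its train and runtime while loops. A toy TF 1.x sketch of the pattern the comment describes (assumed variable names, not NeuralMonkey code), where a tensor shared by two loops is created up front and merely closed over by both loop bodies:

import tensorflow as tf

features = tf.random_normal([8, 10, 32])  # built once, outside any loop

def body(i, acc):
    # both loop bodies may close over `features`, since it already exists
    return [i + 1, acc + tf.reduce_sum(features)]

train_loop = tf.while_loop(lambda i, a: i < 5, body, [0, 0.0])
runtime_loop = tf.while_loop(lambda i, a: i < 3, body, [0, 0.0])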

neuralmonkey/config/normalize.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
+"""Module for configuration normalization.
+
+The `[main]` configuration section contains arguments that can be filled with
+different types of values, e.g. `trainer` can be either a single trainer
+object or a list of them. This module provides functions for unifying the
+configuration interface.
+"""
+
+from argparse import Namespace
+from datetime import timedelta
+import re
+import time
+from typing import List, Union, Callable
+
+import numpy as np
+
+from neuralmonkey.dataset import BatchingScheme
+from neuralmonkey.logging import warn
+from neuralmonkey.tf_manager import get_default_tf_manager
+from neuralmonkey.trainers.delayed_update_trainer import DelayedUpdateTrainer
+
+
+def normalize_configuration(cfg: Namespace, train_mode: bool) -> None:
+    """Given a configuration namespace, normalize the values it contains.
+
+    Arguments:
+        cfg: The namespace object returned by `Configuration.make_namespace`
+        train_mode: Boolean flag controlling normalization of parameters only
+            used during training.
+    """
+    if train_mode:
+        _normalize_train_cfg(cfg)
+
+    if cfg.tf_manager is None:
+        cfg.tf_manager = get_default_tf_manager()
+
+    if (cfg.batch_size is None) == (cfg.batching_scheme is None):
+        raise ValueError("You must specify either batch_size or "
+                         "batching_scheme (not both).")
+
+    if cfg.batch_size is not None:
+        assert cfg.batching_scheme is None
+        cfg.batching_scheme = BatchingScheme(batch_size=cfg.batch_size)
+    else:
+        assert cfg.batching_scheme is not None
+        cfg.batch_size = cfg.batching_scheme.batch_size
+
+    if cfg.runners_batch_size is None:
+        cfg.runners_batch_size = cfg.batching_scheme.batch_size
+
+    cfg.runners_batching_scheme = BatchingScheme(
+        batch_size=cfg.runners_batch_size,
+        token_level_batching=cfg.batching_scheme.token_level_batching,
+        use_leftover_buckets=True)
+
+    cfg.evaluation = [(e[0], e[0], e[1]) if len(e) == 2 else e
+                      for e in cfg.evaluation]
+
+    if cfg.evaluation:
+        cfg.main_metric = "{}/{}".format(cfg.evaluation[-1][0],
+                                         cfg.evaluation[-1][-1].name)
+    else:
+        cfg.main_metric = "{}/{}".format(cfg.runners[-1].decoder_data_id,
+                                         cfg.runners[-1].loss_names[0])
+
+        if not cfg.tf_manager.minimize_metric:
+            raise ValueError("minimize_metric must be set to True in "
+                             "TensorFlowManager when using loss as "
+                             "the main metric")
+
+
+def _normalize_train_cfg(cfg: Namespace) -> None:
+    """Given a configuration namespace, normalize the values it contains.
+
+    This function is only executed when training mode has been invoked.
+
+    Arguments:
+        cfg: The namespace object returned by `Configuration.make_namespace`
+    """
+    if not isinstance(cfg.val_dataset, List):
+        cfg.val_datasets = [cfg.val_dataset]
+    else:
+        cfg.val_datasets = cfg.val_dataset
+
+    if not isinstance(cfg.trainer, List):
+        cfg.trainers = [cfg.trainer]
+    else:
+        cfg.trainers = cfg.trainer
+
+    # deal with delayed trainer and logging periods
+    # the correct way if there are more trainers is perhaps to do a
+    # lowest common denominator of their batches_per_update.
+    # But we can also warn because it is a very weird setup.
+
+    delayed_trainers = [t for t in cfg.trainers
+                        if isinstance(t, DelayedUpdateTrainer)]
+
+    denominator = 1
+    if len(cfg.trainers) > 1 and delayed_trainers:
+        warn("Weird setup: using more trainers and one of them is delayed "
+             "update trainer. No-one can vouch for your safety, user!")
+        warn("Using the lowest common denominator of all delayed trainers'"
+             " batches_per_update parameters for logging period")
+        warn("Note that if you are using a multi-task trainer, it is on "
+             "your own risk")
+
+        denominator = np.lcm.reduce([t.batches_per_update
+                                     for t in delayed_trainers])
+    elif delayed_trainers:
+        assert len(cfg.trainers) == 1
+        denominator = cfg.trainers[0].batches_per_update
+
+    cfg.log_timer = _resolve_period(cfg.logging_period, denominator)
+    cfg.val_timer = _resolve_period(cfg.validation_period, denominator)
+
+
+def _resolve_period(period: Union[str, int],
+                    denominator: int) -> Callable[[int, float], bool]:
+    """Convert logging period into a function for logging time checks.
+
+    Logging and validation periods can both be provided either as a number of
+    batches after which to log/validate, or as a time interval between the
+    logs/validation runs.
+
+    This function unifies both representations into a function that decides
+    whether to log/validate based on a given training step and time since the
+    last log/validation.
+
+    Arguments:
+        period: Either a string representing time, or a number representing
+            number of batches.
+        denominator: Only allow logging when the given step (number of batches
+            since the start of the training) is divisible by this value.
+            This is used e.g. when `DelayedUpdateTrainer` is used.
+
+    Returns:
+        A function of the current training step and time since the last logging
+        period that returns a boolean value.
+    """
+    def get_batch_logger(period: int) -> Callable[[int, float], bool]:
+        def is_time(step: int, _: float) -> bool:
+            return step != 0 and step % period == 0
+        return is_time
+
+    def get_time_logger(period: float) -> Callable[[int, float], bool]:
+        def is_time(step: int, last_time: float) -> bool:
+            if step % denominator != 0:
+                return False
+            return last_time + period < time.process_time()
+        return is_time
+
+    if isinstance(period, int):
+        if period % denominator != 0:
+            raise ValueError(
+                "When using delayed update trainer, the logging/validation "
+                "periods must be divisible by batches_per_update.")
+
+        return get_batch_logger(period)
+
+    regex = re.compile(
+        r"((?P<days>\d+?)d)?((?P<hours>\d+?)h)?((?P<minutes>\d+?)m)?"
+        r"((?P<seconds>\d+?)s)?")
+    parts = regex.match(period)
+
+    if not parts:
+        raise ValueError(
+            "Validation or logging period have incorrect format. "
+            "It should be in format: 3h; 5m; 14s")
+
+    time_params = {}
+    for (name, param) in parts.groupdict().items():
+        if param:
+            time_params[name] = int(param)
+
+    delta_seconds = timedelta(**time_params).total_seconds()
+    if delta_seconds <= 0:
+        raise ValueError("Validation or logging period must be bigger than 0")
+
+    return get_time_logger(delta_seconds)
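
Note: for illustration, a hypothetical use of _resolve_period (not part of the diff), showing both the batch-count and the time-interval form of the period argument:

import time
from neuralmonkey.config.normalize import _resolve_period

# integer period: log every 500 batches (step 0 never logs)
batch_timer = _resolve_period(500, denominator=1)
batch_timer(1000, 0.0)   # True: 1000 is a non-zero multiple of 500
batch_timer(999, 0.0)    # False

# string period: log once at least 1h30m of process time passed since last log
time_timer = _resolve_period("1h30m", denominator=1)
time_timer(10, time.process_time() - 6000)   # True (6000 s > 5400 s)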

neuralmonkey/decoders/autoregressive.py

Lines changed: 39 additions & 29 deletions
@@ -15,7 +15,7 @@
 from neuralmonkey.model.feedable import FeedDict
 from neuralmonkey.model.parameterized import InitializerSpecs
 from neuralmonkey.model.model_part import ModelPart
-from neuralmonkey.logging import log, warn
+from neuralmonkey.logging import warn
 from neuralmonkey.model.sequence import EmbeddedSequence
 from neuralmonkey.nn.utils import dropout
 from neuralmonkey.tf_utils import get_variable, get_state_shape_invariants
@@ -134,52 +134,62 @@ def __init__(self,
         ModelPart.__init__(self, name, reuse, save_checkpoint, load_checkpoint,
                            initializers)
 
-        log("Initializing decoder, name: '{}'".format(name))
-
         self.vocabulary = vocabulary
         self.data_id = data_id
         self.max_output_len = max_output_len
         self.dropout_keep_prob = dropout_keep_prob
-        self.embedding_size = embedding_size
+        self._embedding_size = embedding_size
         self.embeddings_source = embeddings_source
         self.label_smoothing = label_smoothing
         self.tie_embeddings = tie_embeddings
         self.supress_unk = supress_unk
 
-        self.encoder_states = []  # type: List[tf.Tensor]
-        self.encoder_masks = []  # type: List[tf.Tensor]
+        self.encoder_states = lambda: []  # type: Callable[[], List[tf.Tensor]]
+        self.encoder_masks = lambda: []  # type: Callable[[], List[tf.Tensor]]
 
         # Check the values of the parameters (max_output_len, ...)
-        if max_output_len <= 0:
-            raise ValueError("Maximum sequence length must be "
-                             "a positive integer.")
+        if self.max_output_len <= 0:
+            raise ValueError(
+                "Maximum sequence length must be a positive integer.")
 
-        if dropout_keep_prob < 0.0 or dropout_keep_prob > 1.0:
-            raise ValueError("Dropout keep probability must be"
-                             "a real number in the interval [0,1].")
+        if self._embedding_size is not None and self._embedding_size <= 0:
+            raise ValueError("Embedding size must be a positive integer.")
 
-        if self.embedding_size is None and self.embeddings_source is None:
-            raise ValueError("You must specify either embedding size or the "
-                             "embedded sequence from which to reuse the "
-                             "embeddings (e.g. set either 'embedding_size' or "
-                             " 'embeddings_source' parameter)")
+        if self.dropout_keep_prob < 0.0 or self.dropout_keep_prob > 1.0:
+            raise ValueError("Dropout keep probability must be a real number "
+                             "in the interval [0,1].")
+    # pylint: enable=too-many-arguments,too-many-locals
+
+    @property
+    def embedding_size(self) -> int:
+        if self.embeddings_source is None:
+            if self._embedding_size is None:
+                raise ValueError(
+                    "You must specify either embedding size or the embedded "
+                    "sequence from which to reuse the embeddings (e.g. set "
+                    "'embedding_size' or 'embeddings_source' parameter)")
+            return self._embedding_size
 
         if self.embeddings_source is not None:
-            if self.embedding_size is not None:
-                warn("Overriding the embedding_size parameter with the"
-                     " size of the reused embeddings from the encoder.")
+            if self._embedding_size is not None:
+                warn("Overriding the embedding_size parameter with the "
+                     "size of the reused embeddings from the encoder.")
 
-            self.embedding_size = (
-                self.embeddings_source.embedding_matrix.get_shape()[1].value)
+            return self.embeddings_source.embedding_matrix.get_shape()[1].value
 
-        with self.use_scope():
-            self.go_symbols = tf.placeholder(tf.int32, [None], "go_symbols")
+    # pylint: disable=no-self-use
+    @tensor
+    def go_symbols(self) -> tf.Tensor:
+        return tf.placeholder(tf.int32, [None], "go_symbols")
 
-            self.train_inputs = tf.placeholder(
-                tf.int32, [None, None], "train_inputs")
-            self.train_mask = tf.placeholder(
-                tf.float32, [None, None], "train_mask")
-        # pylint: enable=too-many-arguments,too-many-locals
+    @tensor
+    def train_inputs(self) -> tf.Tensor:
+        return tf.placeholder(tf.int32, [None, None], "train_inputs")
+
+    @tensor
+    def train_mask(self) -> tf.Tensor:
+        return tf.placeholder(tf.float32, [None, None], "train_mask")
+    # pylint: enable=no-self-use
 
     @tensor
     def decoding_w(self) -> tf.Variable:
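
Note: the net effect of the embedding_size rewrite is that the consistency check moves from construction time to first use. A tiny, self-contained illustration of that behavioural shift (a toy class, not the NeuralMonkey decoder; the constant 512 merely stands in for the width of a reused embedding matrix):

class Toy:
    def __init__(self, embedding_size=None, embeddings_source=None):
        self._embedding_size = embedding_size
        self.embeddings_source = embeddings_source  # no validation here

    @property
    def embedding_size(self):
        if self.embeddings_source is None:
            if self._embedding_size is None:
                raise ValueError(
                    "specify embedding_size or embeddings_source")
            return self._embedding_size
        return 512  # stands in for the reused embedding matrix width

Toy()                                            # constructing without either argument works
print(Toy(embedding_size=620).embedding_size)    # 620
try:
    print(Toy().embedding_size)                  # the error surfaces only on access
except ValueError as err:
    print(err)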

neuralmonkey/decoders/beam_search_decoder.py

Lines changed: 4 additions & 2 deletions
@@ -163,13 +163,15 @@ def __init__(self,
         # the beam. We need to access all the inner states of the network in
         # the graph, replace them with beam-size-times copied originals, create
         # the beam search graph, and then replace the inner states back.
+        self._building = False
+
         enc_states = self.parent_decoder.encoder_states
         enc_masks = self.parent_decoder.encoder_masks
 
         setattr(self.parent_decoder, "encoder_states",
-                [self.expand_to_beam(states) for states in enc_states])
+                lambda: [self.expand_to_beam(sts) for sts in enc_states()])
         setattr(self.parent_decoder, "encoder_masks",
-                [self.expand_to_beam(mask) for mask in enc_masks])
+                lambda: [self.expand_to_beam(mask) for mask in enc_masks()])
 
         # Create the beam search symbolic graph.
         with self.use_scope():
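
Note: the switch from lists to zero-argument lambdas defers the expand_to_beam calls until the parent decoder actually reads its encoder states. A generic, runnable illustration of this deferral pattern (plain Python, not the NeuralMonkey API):

def expand_to_beam(state):
    return "beam({})".format(state)

encoder_states = lambda: ["state_a", "state_b"]   # stands in for tensors

# a wrapper (like the beam search decoder) replaces the callable...
original = encoder_states
encoder_states = lambda: [expand_to_beam(s) for s in original()]

# ...and nothing is expanded until somebody finally calls it:
print(encoder_states())   # ['beam(state_a)', 'beam(state_b)']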

neuralmonkey/decoders/ctc_decoder.py

Lines changed: 0 additions & 2 deletions
@@ -6,7 +6,6 @@
 
 from neuralmonkey.dataset import Dataset
 from neuralmonkey.decorators import tensor
-from neuralmonkey.logging import log
 from neuralmonkey.model.feedable import FeedDict
 from neuralmonkey.model.parameterized import InitializerSpecs
 from neuralmonkey.model.model_part import ModelPart
@@ -47,7 +46,6 @@ def __init__(self,
         self.merge_repeated_targets = merge_repeated_targets
         self.merge_repeated_outputs = merge_repeated_outputs
         self.beam_width = beam_width
-        log("CTC output tensor {}.".format(self.decoded))
     # pylint: enable=too-many-arguments
 
     # pylint: disable=no-self-use
