
Commit b4ea78c

Merge pull request #778 from ufal/tf-data-2b
Towards TF dataset, part II
2 parents 8515d6c + c810d0b commit b4ea78c


74 files changed (+3664, -1075 lines)

neuralmonkey/attention/combination.py

Lines changed: 45 additions & 27 deletions

@@ -24,6 +24,7 @@
 from neuralmonkey.attention.namedtuples import HierarchicalLoopState
 from neuralmonkey.checking import assert_shape
 from neuralmonkey.decorators import tensor
+from neuralmonkey.logging import debug
 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.model.parameterized import InitializerSpecs
 from neuralmonkey.tf_utils import get_variable
@@ -138,53 +139,66 @@ def __init__(self,
             load_checkpoint=load_checkpoint,
             initializers=initializers)
         self._encoders = encoders
+    # pylint: enable=too-many-arguments

-        # pylint: disable=protected-access
-        self._encoders_tensors = [
-            get_attention_states(e) for e in self._encoders]
-        self._encoders_masks = [get_attention_mask(e) for e in self._encoders]
-        # pylint: enable=protected-access
+    @tensor
+    def _encoders_tensors(self) -> List[tf.Tensor]:
+        tensors = [get_attention_states(e) for e in self._encoders]
+        for e_t in tensors:
+            assert_shape(e_t, [-1, -1, -1])
+        return tensors

-        for e_m in self._encoders_masks:
+    @tensor
+    def _encoders_masks(self) -> List[tf.Tensor]:
+        masks = [get_attention_mask(e) for e in self._encoders]
+        for e_m in masks:
             assert_shape(e_m, [-1, -1])

-        for e_t in self._encoders_tensors:
-            assert_shape(e_t, [-1, -1, -1])
+        if self._use_sentinels:
+            masks.append(tf.ones([tf.shape(masks[0])[0], 1]))
+        return masks

-        with self.use_scope():
-            self.encoder_projections_for_logits = \
-                self.get_encoder_projections("logits_projections")
+    @tensor
+    def encoder_projections_for_logits(self) -> List[tf.Tensor]:
+        return self.get_encoder_projections("logits_projections")

-            self.encoder_attn_biases = [
-                get_variable(name="attn_bias_{}".format(i),
-                             shape=[],
+    @tensor
+    def encoder_attn_biases(self) -> List[tf.Variable]:
+        return [get_variable(name="attn_bias_{}".format(i), shape=[],
                              initializer=tf.zeros_initializer())
                 for i in range(len(self._encoders_tensors))]

-            if self._share_projections:
-                self.encoder_projections_for_ctx = \
-                    self.encoder_projections_for_logits
-            else:
-                self.encoder_projections_for_ctx = \
-                    self.get_encoder_projections("context_projections")
-
-            if self._use_sentinels:
-                self._encoders_masks.append(
-                    tf.ones([tf.shape(self._encoders_masks[0])[0], 1]))
+    @tensor
+    def encoder_projections_for_ctx(self) -> List[tf.Tensor]:
+        if self._share_projections:
+            return self.encoder_projections_for_logits
+        return self.get_encoder_projections("context_projections")

-            self.masks_concat = tf.concat(self._encoders_masks, 1)
-        # pylint: enable=too-many-arguments
+    @tensor
+    def masks_concat(self) -> tf.Tensor:
+        return tf.concat(self._encoders_masks, 1)

     def initial_loop_state(self) -> AttentionLoopState:

+        # Similarly to the feed_forward attention, we need to build the encoder
+        # projections and masks before the while loop is entered so they are
+        # not created as a part of the loop
+
+        # pylint: disable=not-an-iterable
+        for val in self.encoder_projections_for_logits:
+            debug(val, "bless")
+        debug(self.masks_concat, "bless")
+
         length = sum(tf.shape(s)[1] for s in self._encoders_tensors)
+        # pylint: enable=not-an-iterable
+
         if self._use_sentinels:
             length += 1

         return empty_attention_loop_state(self.batch_size, length,
                                           self.context_vector_size)

-    def get_encoder_projections(self, scope):
+    def get_encoder_projections(self, scope: str) -> List[tf.Tensor]:
         encoder_projections = []
         with tf.variable_scope(scope):
             for i, encoder_tensor in enumerate(self._encoders_tensors):
@@ -216,9 +230,11 @@ def get_encoder_projections(self, scope):
                 encoder_projections.append(projection)
         return encoder_projections

+    # pylint: disable=unsubscriptable-object
     @property
     def context_vector_size(self) -> int:
         return self.encoder_projections_for_ctx[0].get_shape()[2].value
+    # pylint: enable=unsubscriptable-object

     # pylint: disable=too-many-locals
     def attention(self,
@@ -280,6 +296,7 @@ def attention(self,
         return contexts, next_loop_state
     # pylint: enable=too-many-locals

+    # pylint: disable=not-an-iterable,unsubscriptable-object
     def _tile_encoders_for_beamsearch(self, projected_sentinel):
         sentinel_batch_size = tf.shape(projected_sentinel)[0]
         encoders_batch_size = tf.shape(
@@ -293,6 +310,7 @@ def _tile_encoders_for_beamsearch(self, projected_sentinel):

         return [tf.tile(proj, [beam_size, 1, 1])
                 for proj in self.encoder_projections_for_ctx]
+    # pylint: enable=not-an-iterable,unsubscriptable-object

     def _renorm_softmax(self, logits):
         """Renormalized softmax wrt. attention mask."""

neuralmonkey/attention/coverage.py

Lines changed: 13 additions & 6 deletions

@@ -10,8 +10,10 @@

 from neuralmonkey.attention.base_attention import Attendable
 from neuralmonkey.attention.feed_forward import Attention
+from neuralmonkey.decorators import tensor
 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.model.parameterized import InitializerSpecs
+from neuralmonkey.tf_utils import get_variable


 class CoverageAttention(Attention):
@@ -30,17 +32,22 @@ def __init__(self,
         Attention.__init__(self, name, encoder, dropout_keep_prob, state_size,
                            reuse, save_checkpoint, load_checkpoint,
                            initializers)
-
         self.max_fertility = max_fertility
+    # pylint: enable=too-many-arguments

-        self.coverage_weights = tf.get_variable(
-            "coverage_matrix", [1, 1, 1, self.state_size])
-        self.fertility_weights = tf.get_variable(
+    @tensor
+    def coverage_weights(self) -> tf.Variable:
+        return get_variable("coverage_matrix", [1, 1, 1, self.state_size])
+
+    @tensor
+    def fertility_weights(self) -> tf.Variable:
+        return get_variable(
             "fertility_matrix", [1, 1, self.context_vector_size])

-        self.fertility = 1e-8 + self.max_fertility * tf.sigmoid(
+    @tensor
+    def fertility(self) -> tf.Tensor:
+        return 1e-8 + self.max_fertility * tf.sigmoid(
             tf.reduce_sum(self.fertility_weights * self.attention_states, [2]))
-        # pylint: enable=too-many-arguments

     def get_energies(self, y: tf.Tensor, weights_in_time: tf.Tensor):
         weight_sum = tf.cond(
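For orientation, the fertility property above multiplies fertility_weights of shape [1, 1, context_vector_size] with attention_states of shape [batch, time, context_vector_size], sums over the last axis, and squashes the result through a sigmoid scaled by max_fertility, giving one fertility value per source position. A small NumPy sketch of the same arithmetic with made-up shapes:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    batch, time, dim = 2, 5, 8          # hypothetical sizes
    max_fertility = 3.0

    attention_states = np.random.randn(batch, time, dim)
    fertility_weights = np.random.randn(1, 1, dim)

    # Reduce over the feature axis, then squash: result shape [batch, time].
    fertility = 1e-8 + max_fertility * sigmoid(
        np.sum(fertility_weights * attention_states, axis=2))
    assert fertility.shape == (batch, time)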

neuralmonkey/attention/scaled_dot_product.py

Lines changed: 25 additions & 10 deletions

@@ -12,12 +12,13 @@
 import tensorflow as tf
 from typeguard import check_argument_types

-from neuralmonkey.nn.utils import dropout
-from neuralmonkey.model.model_part import ModelPart
-from neuralmonkey.model.parameterized import InitializerSpecs
 from neuralmonkey.attention.base_attention import (
     BaseAttention, Attendable, get_attention_states, get_attention_mask)
 from neuralmonkey.attention.namedtuples import MultiHeadLoopState
+from neuralmonkey.decorators import tensor
+from neuralmonkey.model.model_part import ModelPart
+from neuralmonkey.model.parameterized import InitializerSpecs
+from neuralmonkey.nn.utils import dropout


 def split_for_heads(x: tf.Tensor, n_heads: int, head_dim: int) -> tf.Tensor:
@@ -263,23 +264,35 @@ def __init__(self,
         self.n_heads = n_heads
         self.dropout_keep_prob = dropout_keep_prob

+        self.keys_encoder = keys_encoder
+
+        if values_encoder is not None:
+            self.values_encoder = values_encoder
+        else:
+            self.values_encoder = self.keys_encoder
+
         if self.n_heads <= 0:
             raise ValueError("Number of heads must be greater than zero.")

         if self.dropout_keep_prob <= 0.0 or self.dropout_keep_prob > 1.0:
             raise ValueError("Dropout keep prob must be inside (0,1].")

-        if values_encoder is None:
-            values_encoder = keys_encoder
-
-        self.attention_keys = get_attention_states(keys_encoder)
-        self.attention_mask = get_attention_mask(keys_encoder)
-        self.attention_values = get_attention_states(values_encoder)
-
         self._variable_scope.set_initializer(tf.variance_scaling_initializer(
             mode="fan_avg", distribution="uniform"))
     # pylint: enable=too-many-arguments

+    @tensor
+    def attention_keys(self) -> tf.Tensor:
+        return get_attention_states(self.keys_encoder)
+
+    @tensor
+    def attention_mask(self) -> tf.Tensor:
+        return get_attention_mask(self.keys_encoder)
+
+    @tensor
+    def attention_values(self) -> tf.Tensor:
+        return get_attention_states(self.values_encoder)
+
     def attention(self,
                   query: tf.Tensor,
                   decoder_prev_state: tf.Tensor,
@@ -346,9 +359,11 @@ def finalize_loop(self, key: str,
             head_weights = last_loop_state.head_weights[i]
             self.histories["{}_head{}".format(key, i)] = head_weights

+    # pylint: disable=no-member
     @property
     def context_vector_size(self) -> int:
         return self.attention_values.get_shape()[-1].value
+    # pylint: enable=no-member

     def visualize_attention(self, key: str, max_outputs: int = 16) -> None:
         for i in range(self.n_heads):
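The hunk above only moves the key, value, and mask tensors behind lazy @tensor properties; the multi-head mechanics themselves (split_for_heads in the surrounding context) are untouched. As a reminder of what that helper conventionally does, here is a hedged sketch of the usual reshape-and-transpose head split (an illustration, not necessarily the exact neuralmonkey implementation):

    import tensorflow as tf

    def split_for_heads_sketch(x: tf.Tensor, n_heads: int,
                               head_dim: int) -> tf.Tensor:
        """Split [batch, time, n_heads * head_dim] into
        [batch, n_heads, time, head_dim]."""
        batch = tf.shape(x)[0]
        time = tf.shape(x)[1]
        split = tf.reshape(x, [batch, time, n_heads, head_dim])
        return tf.transpose(split, [0, 2, 1, 3])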

neuralmonkey/checking.py

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ def check_dataset_and_coders(dataset: Dataset,
     missing = []

     for (serie, coder) in data_list:
-        if not dataset.has_series(serie):
+        if serie not in dataset:
             log("dataset {} does not have serie {}".format(
                 dataset.name, serie))
             missing.append((coder, serie))

neuralmonkey/dataset.py

Lines changed: 5 additions & 5 deletions

@@ -467,11 +467,7 @@ def __len__(self) -> int:
         assert self.length is not None
         return self.length

-    @property
-    def series(self) -> List[str]:
-        return list(sorted(self.iterators.keys()))
-
-    def has_series(self, name: str) -> bool:
+    def __contains__(self, name: str) -> bool:
         """Check if the dataset contains a series of a given name.

         Arguments:
@@ -482,6 +478,10 @@ def has_series(self, name: str) -> bool:
         """
         return name in self.iterators

+    @property
+    def series(self) -> List[str]:
+        return list(sorted(self.iterators.keys()))
+
     def get_series(self, name: str) -> Iterator:
         """Get the data series with a given name.
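Replacing has_series with __contains__ lets callers use Python's native membership operator, which is exactly what the checking.py hunk above switches to. A self-contained toy illustration of the protocol (TinyDataset is a made-up stand-in, not the real neuralmonkey Dataset):

    from typing import Dict, Iterator, List

    class TinyDataset:
        """Minimal stand-in illustrating the membership protocol."""

        def __init__(self, iterators: Dict[str, List]) -> None:
            self.iterators = iterators

        def __contains__(self, name: str) -> bool:
            return name in self.iterators

        @property
        def series(self) -> List[str]:
            return list(sorted(self.iterators.keys()))

        def get_series(self, name: str) -> Iterator:
            return iter(self.iterators[name])

    dataset = TinyDataset({"source": [["a", "b"]], "target": [["c"]]})
    assert "source" in dataset       # replaces dataset.has_series("source")
    assert dataset.series == ["source", "target"]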

neuralmonkey/decoders/autoregressive.py

Lines changed: 16 additions & 9 deletions

@@ -5,7 +5,7 @@
 The autoregressive decoder uses the while loop to get the outputs.
 Descendants should only specify the initial state and the while loop body.
 """
-from typing import NamedTuple, Callable, Tuple, Optional, Any, List
+from typing import NamedTuple, Callable, Tuple, Optional, Any, List, Dict

 import numpy as np
 import tensorflow as tf
@@ -19,7 +19,9 @@
 from neuralmonkey.model.sequence import EmbeddedSequence
 from neuralmonkey.nn.utils import dropout
 from neuralmonkey.tf_utils import get_variable, get_state_shape_invariants
-from neuralmonkey.vocabulary import Vocabulary, START_TOKEN, UNK_TOKEN_INDEX
+from neuralmonkey.vocabulary import (
+    Vocabulary, START_TOKEN, UNK_TOKEN_INDEX, START_TOKEN_INDEX,
+    PAD_TOKEN_INDEX)


 class LoopState(NamedTuple(
@@ -177,19 +179,25 @@ def embedding_size(self) -> int:

         return self.embeddings_source.embedding_matrix.get_shape()[1].value

-    # pylint: disable=no-self-use
     @tensor
     def go_symbols(self) -> tf.Tensor:
-        return tf.placeholder(tf.int32, [None], "go_symbols")
+        return tf.fill([self.batch_size], START_TOKEN_INDEX)
+
+    @property
+    def input_types(self) -> Dict[str, tf.DType]:
+        return {self.data_id: tf.int32}
+
+    @property
+    def input_shapes(self) -> Dict[str, tf.TensorShape]:
+        return {self.data_id: tf.TensorShape([None, None])}

     @tensor
     def train_inputs(self) -> tf.Tensor:
-        return tf.placeholder(tf.int32, [None, None], "train_inputs")
+        return self.dataset[self.data_id]

     @tensor
     def train_mask(self) -> tf.Tensor:
-        return tf.placeholder(tf.float32, [None, None], "train_mask")
-    # pylint: enable=no-self-use
+        return tf.to_float(tf.not_equal(self.train_inputs, PAD_TOKEN_INDEX))

     @tensor
     def decoding_w(self) -> tf.Variable:
@@ -479,12 +487,11 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
         if sentences is not None:
             sentences_list = list(sentences)
             # train_mode=False, since we don't want to <unk>ize target words!
-            inputs, weights = self.vocabulary.sentences_to_tensor(
+            inputs, _ = self.vocabulary.sentences_to_tensor(
                 sentences_list, self.max_output_len, train_mode=False,
                 add_start_symbol=False, add_end_symbol=True,
                 pad_to_max_len=False)

             fd[self.train_inputs] = inputs
-            fd[self.train_mask] = weights

         return fd
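With this change the training mask is no longer fed through feed_dict; it is derived in-graph from the padded train_inputs by comparing against PAD_TOKEN_INDEX. A small TF 1.x sketch of the same computation (assuming, for illustration only, that the padding index is 0):

    import tensorflow as tf

    PAD_TOKEN_INDEX = 0  # assumed value for this illustration

    # A padded batch of token ids, shape [batch, time].
    train_inputs = tf.constant([[5, 7, 2, 0, 0],
                                [3, 0, 0, 0, 0]])

    # 1.0 for real tokens, 0.0 for padding; mirrors the new train_mask.
    train_mask = tf.to_float(tf.not_equal(train_inputs, PAD_TOKEN_INDEX))

    with tf.Session() as sess:
        print(sess.run(train_mask))
        # [[1. 1. 1. 0. 0.]
        #  [1. 0. 0. 0. 0.]]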

neuralmonkey/decoders/beam_search_decoder.py

Lines changed: 19 additions & 5 deletions

@@ -21,7 +21,10 @@
 """
 # pylint: disable=too-many-lines
 # Maybe move the definitions of the named tuple structures to a separate file?
-from typing import NamedTuple, List, Callable, Any
+from typing import Any, Callable, List, NamedTuple
+# pylint: disable=unused-import
+from typing import Optional
+# pylint: enable=unused-import

 import tensorflow as tf
 from typeguard import check_argument_types
@@ -157,13 +160,16 @@ def __init__(self,
         # Create a placeholder for maximum number of steps that is necessary
         # during ensembling, when the decoder is called repetitively with the
         # max_steps attribute set to one.
-        self.max_steps = tf.placeholder_with_default(max_steps, [])
+        self.max_steps = tf.placeholder_with_default(self.max_steps_int, [])

+        self._initial_loop_state = None  # type: Optional[BeamSearchLoopState]
+
+    @tensor
+    def outputs(self) -> tf.Tensor:
         # This is an ugly hack for handling the whole graph when expanding to
         # the beam. We need to access all the inner states of the network in
         # the graph, replace them with beam-size-times copied originals, create
         # the beam search graph, and then replace the inner states back.
-        self._building = False

         enc_states = self.parent_decoder.encoder_states
         enc_masks = self.parent_decoder.encoder_masks
@@ -175,13 +181,21 @@

         # Create the beam search symbolic graph.
         with self.use_scope():
-            self.initial_loop_state = self.get_initial_loop_state()
-            self.outputs = self.decoding_loop()
+            self._initial_loop_state = self.get_initial_loop_state()
+            outputs = self.decoding_loop()

         # Reassign the original encoder states and mask back
         setattr(self.parent_decoder, "encoder_states", enc_states)
         setattr(self.parent_decoder, "encoder_masks", enc_masks)

+        return outputs
+
+    @property
+    def initial_loop_state(self) -> BeamSearchLoopState:
+        if self._initial_loop_state is None:
+            raise RuntimeError("Initial loop state was not initialized")
+        return self._initial_loop_state
+
     @property
     def vocabulary(self) -> Vocabulary:
         return self.parent_decoder.vocabulary
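The beam search graph is now constructed lazily inside the outputs property, and initial_loop_state becomes valid only as a side effect of that build, hence the guard that raises RuntimeError on early access. A generic sketch of this guarded-accessor pattern (names are illustrative, not the decoder's API):

    from typing import List, Optional

    class LazyBuilder:
        """Toy example: a by-product is filled while building the main
        output and guarded against access before the build has run."""

        def __init__(self) -> None:
            self._state = None    # type: Optional[List[int]]
            self._outputs = None  # type: Optional[List[int]]

        @property
        def outputs(self) -> List[int]:
            if self._outputs is None:
                self._state = [0]          # side effect of the build
                self._outputs = [1, 2, 3]  # the actual result
            return self._outputs

        @property
        def state(self) -> List[int]:
            if self._state is None:
                raise RuntimeError("state not built yet; access outputs first")
            return self._state

    builder = LazyBuilder()
    _ = builder.outputs   # triggers the build
    print(builder.state)  # safe to read now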
