Skip to content

Commit f2204cd

Browse files
committed
Moving creation of data_id-related placeholders to feedable class
Plus, adding a pseudo-dummy register_input function which constructs the placeholders.
1 parent d7e2044 commit f2204cd

16 files changed

+299
-168
lines changed

neuralmonkey/decoders/autoregressive.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
The autoregressive decoder uses the while loop to get the outputs.
66
Descendants should only specify the initial state and the while loop body.
77
"""
8-
from typing import NamedTuple, Callable, Tuple, Optional, Any, List
8+
from typing import NamedTuple, Callable, Tuple, Optional, Any, List, Dict
99

1010
import numpy as np
1111
import tensorflow as tf
@@ -19,7 +19,9 @@
1919
from neuralmonkey.model.sequence import EmbeddedSequence
2020
from neuralmonkey.nn.utils import dropout
2121
from neuralmonkey.tf_utils import get_variable, get_state_shape_invariants
22-
from neuralmonkey.vocabulary import Vocabulary, START_TOKEN, UNK_TOKEN_INDEX
22+
from neuralmonkey.vocabulary import (
23+
Vocabulary, START_TOKEN, UNK_TOKEN_INDEX, START_TOKEN_INDEX,
24+
PAD_TOKEN_INDEX)
2325

2426

2527
class LoopState(NamedTuple(
@@ -177,19 +179,25 @@ def embedding_size(self) -> int:
177179

178180
return self.embeddings_source.embedding_matrix.get_shape()[1].value
179181

180-
# pylint: disable=no-self-use
181182
@tensor
182183
def go_symbols(self) -> tf.Tensor:
183-
return tf.placeholder(tf.int32, [None], "go_symbols")
184+
return tf.fill([self.batch_size], START_TOKEN_INDEX)
185+
186+
@property
187+
def input_types(self) -> Dict[str, tf.DType]:
188+
return {self.data_id: tf.int32}
189+
190+
@property
191+
def input_shapes(self) -> Dict[str, tf.TensorShape]:
192+
return {self.data_id: tf.TensorShape([None, None])}
184193

185194
@tensor
186195
def train_inputs(self) -> tf.Tensor:
187-
return tf.placeholder(tf.int32, [None, None], "train_inputs")
196+
return self.dataset[self.data_id]
188197

189198
@tensor
190199
def train_mask(self) -> tf.Tensor:
191-
return tf.placeholder(tf.float32, [None, None], "train_mask")
192-
# pylint: enable=no-self-use
200+
return tf.to_float(tf.not_equal(self.train_inputs, PAD_TOKEN_INDEX))
193201

194202
@tensor
195203
def decoding_w(self) -> tf.Variable:
@@ -479,12 +487,11 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
479487
if sentences is not None:
480488
sentences_list = list(sentences)
481489
# train_mode=False, since we don't want to <unk>ize target words!
482-
inputs, weights = self.vocabulary.sentences_to_tensor(
490+
inputs, _ = self.vocabulary.sentences_to_tensor(
483491
sentences_list, self.max_output_len, train_mode=False,
484492
add_start_symbol=False, add_end_symbol=True,
485493
pad_to_max_len=False)
486494

487495
fd[self.train_inputs] = inputs
488-
fd[self.train_mask] = weights
489496

490497
return fd

neuralmonkey/decoders/beam_search_decoder.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
"""
2222
# pylint: disable=too-many-lines
2323
# Maybe move the definitions of the named tuple structures to a separate file?
24-
from typing import NamedTuple, List, Callable, Any
24+
from typing import Any, Callable, List, NamedTuple
25+
# pylint: disable=unused-import
26+
from typing import Optional
27+
# pylint: enable=unused-import
2528

2629
import tensorflow as tf
2730
from typeguard import check_argument_types
@@ -159,11 +162,14 @@ def __init__(self,
159162
# max_steps attribute set to one.
160163
self.max_steps = tf.placeholder_with_default(self.max_steps_int, [])
161164

165+
self._initial_loop_state = None # type: Optional[BeamSearchLoopState]
166+
167+
@tensor
168+
def outputs(self) -> tf.Tensor:
162169
# This is an ugly hack for handling the whole graph when expanding to
163170
# the beam. We need to access all the inner states of the network in
164171
# the graph, replace them with beam-size-times copied originals, create
165172
# the beam search graph, and then replace the inner states back.
166-
self._building = False
167173

168174
enc_states = self.parent_decoder.encoder_states
169175
enc_masks = self.parent_decoder.encoder_masks
@@ -175,13 +181,21 @@ def __init__(self,
175181

176182
# Create the beam search symbolic graph.
177183
with self.use_scope():
178-
self.initial_loop_state = self.get_initial_loop_state()
179-
self.outputs = self.decoding_loop()
184+
self._initial_loop_state = self.get_initial_loop_state()
185+
outputs = self.decoding_loop()
180186

181187
# Reassign the original encoder states and mask back
182188
setattr(self.parent_decoder, "encoder_states", enc_states)
183189
setattr(self.parent_decoder, "encoder_masks", enc_masks)
184190

191+
return outputs
192+
193+
@property
194+
def initial_loop_state(self) -> BeamSearchLoopState:
195+
if self._initial_loop_state is None:
196+
raise RuntimeError("Initial loop state was not initialized")
197+
return self._initial_loop_state
198+
185199
@property
186200
def vocabulary(self) -> Vocabulary:
187201
return self.parent_decoder.vocabulary

neuralmonkey/decoders/classifier.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, List
1+
from typing import Callable, Dict, List
22

33
import tensorflow as tf
44
from typeguard import check_argument_types
@@ -62,11 +62,17 @@ def __init__(self,
6262
self.max_output_len = 1
6363
# pylint: enable=too-many-arguments
6464

65-
# pylint: disable=no-self-use
65+
@property
66+
def input_types(self) -> Dict[str, tf.DType]:
67+
return {self.data_id: tf.int32}
68+
69+
@property
70+
def input_shapes(self) -> Dict[str, tf.TensorShape]:
71+
return {self.data_id: tf.TensorShape([None])}
72+
6673
@tensor
6774
def gt_inputs(self) -> tf.Tensor:
68-
return tf.placeholder(tf.int32, [None], "targets")
69-
# pylint: enable=no-self-use
75+
return self.dataset[self.data_id]
7076

7177
@tensor
7278
def _mlp(self) -> MultilayerPerceptron:

neuralmonkey/decoders/sequence_labeler.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Dict, Union
22

33
import tensorflow as tf
44
from typeguard import check_argument_types
@@ -11,7 +11,7 @@
1111
from neuralmonkey.model.parameterized import InitializerSpecs
1212
from neuralmonkey.model.model_part import ModelPart
1313
from neuralmonkey.tf_utils import get_variable
14-
from neuralmonkey.vocabulary import Vocabulary
14+
from neuralmonkey.vocabulary import Vocabulary, PAD_TOKEN_INDEX
1515

1616

1717
class SequenceLabeler(ModelPart):
@@ -38,15 +38,21 @@ def __init__(self,
3838
self.dropout_keep_prob = dropout_keep_prob
3939
# pylint: enable=too-many-arguments
4040

41-
# pylint: disable=no-self-use
41+
@property
42+
def input_types(self) -> Dict[str, tf.DType]:
43+
return {self.data_id: tf.int32}
44+
45+
@property
46+
def input_shapes(self) -> Dict[str, tf.TensorShape]:
47+
return {self.data_id: tf.TensorShape([None, None])}
48+
4249
@tensor
4350
def train_targets(self) -> tf.Tensor:
44-
return tf.placeholder(tf.int32, [None, None], "targets")
51+
return self.dataset[self.data_id]
4552

4653
@tensor
47-
def train_weights(self) -> tf.Tensor:
48-
return tf.placeholder(tf.float32, [None, None], "padding")
49-
# pylint: enable=no-self-use
54+
def train_mask(self) -> tf.Tensor:
55+
return tf.to_float(tf.not_equal(self.train_targets, PAD_TOKEN_INDEX))
5056

5157
@property
5258
def rnn_size(self) -> int:
@@ -116,7 +122,7 @@ def cost(self) -> tf.Tensor:
116122

117123
# loss is now of shape [batch, time]. Need to mask it now by
118124
# element-wise multiplication with weights placeholder
119-
weighted_loss = loss * self.train_weights
125+
weighted_loss = loss * self.train_mask
120126
return tf.reduce_sum(weighted_loss)
121127

122128
@property
@@ -132,10 +138,8 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
132138

133139
sentences = dataset.maybe_get_series(self.data_id)
134140
if sentences is not None:
135-
vectors, paddings = self.vocabulary.sentences_to_tensor(
141+
vectors, _ = self.vocabulary.sentences_to_tensor(
136142
list(sentences), pad_to_max_len=False, train_mode=train)
137143

138144
fd[self.train_targets] = vectors.T
139-
fd[self.train_weights] = paddings.T
140-
141145
return fd

neuralmonkey/decoders/sequence_regressor.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, List
1+
from typing import Callable, Dict, List
22

33
import tensorflow as tf
44
from typeguard import check_argument_types
@@ -46,11 +46,17 @@ def __init__(self,
4646
self._dropout_keep_prob = dropout_keep_prob
4747
# pylint: enable=too-many-arguments
4848

49-
# pylint: disable=no-self-use
49+
@property
50+
def input_types(self) -> Dict[str, tf.DType]:
51+
return {self.data_id: tf.float32}
52+
53+
@property
54+
def input_shapes(self) -> Dict[str, tf.TensorShape]:
55+
return {self.data_id: tf.TensorShape([None])}
56+
5057
@tensor
5158
def train_inputs(self) -> tf.Tensor:
52-
return tf.placeholder(tf.float32, [None], "targets")
53-
# pylint: enable=no-self-use
59+
return self.dataset[self.data_id]
5460

5561
@tensor
5662
def _mlp_input(self):

neuralmonkey/decoders/word_alignment_decoder.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import cast, Tuple
1+
# TODO untested module
2+
from typing import cast, Dict, Tuple
23

34
import numpy as np
45
import tensorflow as tf
@@ -42,14 +43,18 @@ def enc_input(self) -> Sequence:
4243

4344
return cast(Sequence, self.encoder.input_sequence)
4445

46+
@property
47+
def input_types(self) -> Dict[str, tf.DType]:
48+
return {self.data_id: tf.float32}
49+
50+
@property
51+
def input_shapes(self) -> Dict[str, tf.TensorShape]:
52+
return {self.data_id: tf.TensorShape(
53+
[None, self.decoder.max_output_len, self.enc_input.max_length])}
54+
4555
@tensor
4656
def ref_alignment(self) -> tf.Tensor:
47-
# TODO dynamic shape?
48-
return tf.placeholder(
49-
dtype=tf.float32,
50-
shape=[None, self.decoder.max_output_len,
51-
self.enc_input.max_length],
52-
name="ref_alignment")
57+
return self.dataset[self.data_id]
5358

5459
@tensor
5560
def alignment_target(self) -> tf.Tensor:

neuralmonkey/encoders/cnn_encoder.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""CNN for image processing."""
22

3-
from typing import cast, Callable, List, Tuple, Union
3+
from typing import cast, Callable, Dict, List, Tuple, Union
44
from typeguard import check_argument_types
55

66
import numpy as np
@@ -83,20 +83,24 @@ def __init__(self,
8383
self.batch_normalize = batch_normalize
8484
# pylint: enable=too-many-arguments, too-many-locals
8585

86+
@property
87+
def input_types(self) -> Dict[str, tf.DType]:
88+
return {self.data_id: tf.float32}
89+
90+
@property
91+
def input_shapes(self) -> Dict[str, tf.TensorShape]:
92+
return {self.data_id: tf.TensorShape(
93+
[None, self.image_height, self.image_width, self.pixel_dim])}
94+
8695
@tensor
8796
def image_input(self) -> tf.Tensor:
88-
return tf.placeholder(
89-
tf.float32,
90-
shape=(None, self.image_height, self.image_width,
91-
self.pixel_dim),
92-
name="input_images")
97+
return self.dataset[self.data_id]
9398

9499
@tensor
95100
def image_mask(self) -> tf.Tensor:
96-
return tf.placeholder(
97-
tf.float32,
98-
shape=(None, self.image_height, self.image_width, 1),
99-
name="input_mask")
101+
# the image mask is one everywhere where the image is non-zero, i.e.
102+
# zero pixels are masked out
103+
return tf.sign(tf.reduce_sum(self.image_input, axis=3, keepdims=True))
100104

101105
def batch_norm_callback(self, layer_output: tf.Tensor) -> tf.Tensor:
102106
if self.batch_normalize:
@@ -198,13 +202,7 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
198202
# if it is from the pickled file, it is a list, not a numpy tensor,
199203
# so convert it as as a prevention
200204
images = np.array(list(dataset.get_series(self.data_id)))
201-
202205
fd[self.image_input] = images / 255.0
203-
204-
# the image mask is one everywhere where the image is non-zero, i.e.
205-
# zero pixels are masked out
206-
fd[self.image_mask] = np.sign(np.sum(images, axis=3, keepdims=True))
207-
208206
return fd
209207

210208

neuralmonkey/encoders/imagenet_encoder.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""Pre-trained ImageNet networks."""
22

3-
from typing import Callable, NamedTuple, Tuple, Optional, Any
43
import sys
4+
from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple
55

6-
from typeguard import check_argument_types
76
import numpy as np
87
import tensorflow as tf
98
import tensorflow.contrib.slim as tf_slim
@@ -12,6 +11,7 @@
1211
# see https://github.com/tensorflow/tensorflow/issues/6064
1312
import tensorflow.contrib.slim.nets
1413
# pylint: enable=unused-import
14+
from typeguard import check_argument_types
1515

1616
from neuralmonkey.dataset import Dataset
1717
from neuralmonkey.decorators import tensor
@@ -158,6 +158,19 @@ def __init__(self,
158158
self.net_specification = SUPPORTED_NETWORKS[self.network_type]()
159159
self.height, self.width = self.net_specification.image_size
160160

161+
@property
162+
def input_types(self) -> Dict[str, tf.DType]:
163+
return {self.data_id: tf.float32}
164+
165+
@property
166+
def input_shapes(self) -> Dict[str, tf.TensorShape]:
167+
return {
168+
self.data_id: tf.TensorShape([None, self.height, self.width, 3])}
169+
170+
@tensor
171+
def input_image(self) -> tf.Tensor:
172+
return self.dataset[self.data_id]
173+
161174
@tensor
162175
def end_points(self) -> Any:
163176
with tf_slim.arg_scope(self.net_specification.scope()):
@@ -187,11 +200,6 @@ def end_points(self) -> Any:
187200

188201
return end_points
189202

190-
@tensor
191-
def input_image(self) -> tf.Tensor:
192-
return tf.placeholder(
193-
tf.float32, [None, self.height, self.width, 3])
194-
195203
@tensor
196204
def spatial_states(self) -> Optional[tf.Tensor]:
197205
if self.spatial_layer is None:

0 commit comments

Comments
 (0)