
Commit 24df477

Lazifying most of the model parts
1 parent 3b876eb commit 24df477

17 files changed: +215, -143 lines
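
The change converts attributes that were built eagerly in __init__ (typically inside "with self.use_scope():" blocks) into @tensor-decorated methods, so each piece of the TensorFlow graph is constructed only when it is first accessed. The standalone sketch below illustrates the lazy-caching pattern; lazy_tensor and ModelPartSketch are hypothetical names invented for the example, and the real neuralmonkey.decorators.tensor additionally builds the value inside the model part's variable scope.

import functools


def lazy_tensor(func):
    """Caching property: build the value on first access, then reuse it.

    A simplified stand-in for neuralmonkey.decorators.tensor (assumption:
    the real decorator also enters the model part's variable scope).
    """
    attr = "_cached_" + func.__name__

    @property
    @functools.wraps(func)
    def wrapper(self):
        if not hasattr(self, attr):
            setattr(self, attr, func(self))  # construct once, cache
        return getattr(self, attr)
    return wrapper


class ModelPartSketch:
    """Hypothetical model part: nothing is built until first access."""

    def __init__(self, state_size: int) -> None:
        self.state_size = state_size  # cheap configuration only

    @lazy_tensor
    def coverage_weights(self):
        # The real code would call get_variable(...) here; this only shows
        # that construction happens exactly once.
        print("building coverage_weights")
        return ("variable", [1, 1, 1, self.state_size])


part = ModelPartSketch(64)
_ = part.coverage_weights  # prints "building coverage_weights"
_ = part.coverage_weights  # cached, nothing is rebuilt

The pylint disables scattered through the diff (not-an-iterable, unsubscriptable-object, no-member) are presumably needed because pylint sees the decorated methods as functions rather than the lists and tensors they return at runtime, as the "pylint property bug" comment in classifier.py suggests.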

neuralmonkey/attention/combination.py

Lines changed: 42 additions & 27 deletions
@@ -24,6 +24,7 @@
 from neuralmonkey.attention.namedtuples import HierarchicalLoopState
 from neuralmonkey.checking import assert_shape
 from neuralmonkey.decorators import tensor
+from neuralmonkey.logging import debug
 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.model.parameterized import InitializerSpecs
 from neuralmonkey.tf_utils import get_variable
@@ -138,53 +139,63 @@ def __init__(self,
             load_checkpoint=load_checkpoint,
             initializers=initializers)
         self._encoders = encoders
+    # pylint: enable=too-many-arguments

-        # pylint: disable=protected-access
-        self._encoders_tensors = [
-            get_attention_states(e) for e in self._encoders]
-        self._encoders_masks = [get_attention_mask(e) for e in self._encoders]
-        # pylint: enable=protected-access
+    @tensor
+    def _encoders_tensors(self) -> List[tf.Tensor]:
+        tensors = [get_attention_states(e) for e in self._encoders]
+        for e_t in tensors:
+            assert_shape(e_t, [-1, -1, -1])
+        return tensors

-        for e_m in self._encoders_masks:
+    @tensor
+    def _encoders_masks(self) -> List[tf.Tensor]:
+        masks = [get_attention_mask(e) for e in self._encoders]
+        for e_m in masks:
             assert_shape(e_m, [-1, -1])

-        for e_t in self._encoders_tensors:
-            assert_shape(e_t, [-1, -1, -1])
+        if self._use_sentinels:
+            masks.append(tf.ones([tf.shape(masks[0])[0], 1]))
+        return masks

-        with self.use_scope():
-            self.encoder_projections_for_logits = \
-                self.get_encoder_projections("logits_projections")
+    @tensor
+    def encoder_projections_for_logits(self) -> List[tf.Tensor]:
+        return self.get_encoder_projections("logits_projections")

-            self.encoder_attn_biases = [
-                get_variable(name="attn_bias_{}".format(i),
-                             shape=[],
+    @tensor
+    def encoder_attn_biases(self) -> List[tf.Variable]:
+        return [get_variable(name="attn_bias_{}".format(i), shape=[],
                              initializer=tf.zeros_initializer())
                 for i in range(len(self._encoders_tensors))]

-            if self._share_projections:
-                self.encoder_projections_for_ctx = \
-                    self.encoder_projections_for_logits
-            else:
-                self.encoder_projections_for_ctx = \
-                    self.get_encoder_projections("context_projections")
-
-            if self._use_sentinels:
-                self._encoders_masks.append(
-                    tf.ones([tf.shape(self._encoders_masks[0])[0], 1]))
+    @tensor
+    def encoder_projections_for_ctx(self) -> List[tf.Tensor]:
+        if self._share_projections:
+            return self.encoder_projections_for_logits
+        return self.get_encoder_projections("context_projections")

-            self.masks_concat = tf.concat(self._encoders_masks, 1)
-        # pylint: enable=too-many-arguments
+    @tensor
+    def masks_concat(self) -> tf.Tensor:
+        return tf.concat(self._encoders_masks, 1)

     def initial_loop_state(self) -> AttentionLoopState:

+        # pylint: disable=not-an-iterable
+        # TODO blessing
+        for val in self.encoder_projections_for_logits:
+            debug(val)
+        debug(self.masks_concat)
+
         length = sum(tf.shape(s)[1] for s in self._encoders_tensors)
+        # pylint: enable=not-an-iterable
+
         if self._use_sentinels:
             length += 1

         return empty_attention_loop_state(self.batch_size, length,
                                           self.context_vector_size)

-    def get_encoder_projections(self, scope):
+    def get_encoder_projections(self, scope) -> List[tf.Tensor]:
         encoder_projections = []
         with tf.variable_scope(scope):
             for i, encoder_tensor in enumerate(self._encoders_tensors):
@@ -216,9 +227,11 @@ def get_encoder_projections(self, scope):
                 encoder_projections.append(projection)
         return encoder_projections

+    # pylint: disable=unsubscriptable-object
     @property
     def context_vector_size(self) -> int:
         return self.encoder_projections_for_ctx[0].get_shape()[2].value
+    # pylint: enable=unsubscriptable-object

     # pylint: disable=too-many-locals
     def attention(self,
@@ -280,6 +293,7 @@ def attention(self,
         return contexts, next_loop_state
     # pylint: enable=too-many-locals

+    # pylint: disable=not-an-iterable,unsubscriptable-object
     def _tile_encoders_for_beamsearch(self, projected_sentinel):
         sentinel_batch_size = tf.shape(projected_sentinel)[0]
         encoders_batch_size = tf.shape(
@@ -293,6 +307,7 @@ def _tile_encoders_for_beamsearch(self, projected_sentinel):

         return [tf.tile(proj, [beam_size, 1, 1])
                 for proj in self.encoder_projections_for_ctx]
+    # pylint: enable=not-an-iterable,unsubscriptable-object

     def _renorm_softmax(self, logits):
         """Renormalized softmax wrt. attention mask."""

neuralmonkey/attention/coverage.py

Lines changed: 13 additions & 6 deletions
@@ -10,8 +10,10 @@

 from neuralmonkey.attention.base_attention import Attendable
 from neuralmonkey.attention.feed_forward import Attention
+from neuralmonkey.decorators import tensor
 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.model.parameterized import InitializerSpecs
+from neuralmonkey.tf_utils import get_variable


 class CoverageAttention(Attention):
@@ -30,17 +32,22 @@ def __init__(self,
         Attention.__init__(self, name, encoder, dropout_keep_prob, state_size,
                            reuse, save_checkpoint, load_checkpoint,
                            initializers)
-
         self.max_fertility = max_fertility
+    # pylint: enable=too-many-arguments

-        self.coverage_weights = tf.get_variable(
-            "coverage_matrix", [1, 1, 1, self.state_size])
-        self.fertility_weights = tf.get_variable(
+    @tensor
+    def coverage_weights(self) -> tf.Variable:
+        return get_variable("coverage_matrix", [1, 1, 1, self.state_size])
+
+    @tensor
+    def fertility_weights(self) -> tf.Variable:
+        return get_variable(
             "fertility_matrix", [1, 1, self.context_vector_size])

-        self.fertility = 1e-8 + self.max_fertility * tf.sigmoid(
+    @tensor
+    def fertility(self) -> tf.Tensor:
+        return 1e-8 + self.max_fertility * tf.sigmoid(
             tf.reduce_sum(self.fertility_weights * self.attention_states, [2]))
-        # pylint: enable=too-many-arguments

     def get_energies(self, y: tf.Tensor, weights_in_time: tf.Tensor):
         weight_sum = tf.cond(

neuralmonkey/attention/scaled_dot_product.py

Lines changed: 25 additions & 10 deletions
@@ -12,12 +12,13 @@
 import tensorflow as tf
 from typeguard import check_argument_types

-from neuralmonkey.nn.utils import dropout
-from neuralmonkey.model.model_part import ModelPart
-from neuralmonkey.model.parameterized import InitializerSpecs
 from neuralmonkey.attention.base_attention import (
     BaseAttention, Attendable, get_attention_states, get_attention_mask)
 from neuralmonkey.attention.namedtuples import MultiHeadLoopState
+from neuralmonkey.decorators import tensor
+from neuralmonkey.model.model_part import ModelPart
+from neuralmonkey.model.parameterized import InitializerSpecs
+from neuralmonkey.nn.utils import dropout


 def split_for_heads(x: tf.Tensor, n_heads: int, head_dim: int) -> tf.Tensor:
@@ -263,23 +264,35 @@ def __init__(self,
         self.n_heads = n_heads
         self.dropout_keep_prob = dropout_keep_prob

+        self.keys_encoder = keys_encoder
+
+        if values_encoder is not None:
+            self.values_encoder = values_encoder
+        else:
+            self.values_encoder = self.keys_encoder
+
         if self.n_heads <= 0:
             raise ValueError("Number of heads must be greater than zero.")

         if self.dropout_keep_prob <= 0.0 or self.dropout_keep_prob > 1.0:
             raise ValueError("Dropout keep prob must be inside (0,1].")

-        if values_encoder is None:
-            values_encoder = keys_encoder
-
-        self.attention_keys = get_attention_states(keys_encoder)
-        self.attention_mask = get_attention_mask(keys_encoder)
-        self.attention_values = get_attention_states(values_encoder)
-
         self._variable_scope.set_initializer(tf.variance_scaling_initializer(
             mode="fan_avg", distribution="uniform"))
     # pylint: enable=too-many-arguments

+    @tensor
+    def attention_keys(self) -> tf.Tensor:
+        return get_attention_states(self.keys_encoder)
+
+    @tensor
+    def attention_mask(self) -> tf.Tensor:
+        return get_attention_mask(self.keys_encoder)
+
+    @tensor
+    def attention_values(self) -> tf.Tensor:
+        return get_attention_states(self.values_encoder)
+
     def attention(self,
                   query: tf.Tensor,
                   decoder_prev_state: tf.Tensor,
@@ -346,9 +359,11 @@ def finalize_loop(self, key: str,
             head_weights = last_loop_state.head_weights[i]
             self.histories["{}_head{}".format(key, i)] = head_weights

+    # pylint: disable=no-member
     @property
     def context_vector_size(self) -> int:
         return self.attention_values.get_shape()[-1].value
+    # pylint: enable=no-member

     def visualize_attention(self, key: str, max_outputs: int = 16) -> None:
         for i in range(self.n_heads):

neuralmonkey/decoders/beam_search_decoder.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def __init__(self,
         # Create a placeholder for maximum number of steps that is necessary
         # during ensembling, when the decoder is called repetitively with the
         # max_steps attribute set to one.
-        self.max_steps = tf.placeholder_with_default(max_steps, [])
+        self.max_steps = tf.placeholder_with_default(self.max_steps_int, [])

         # This is an ugly hack for handling the whole graph when expanding to
         # the beam. We need to access all the inner states of the network in
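
For context on this one-line change: tf.placeholder_with_default evaluates to its default unless a value is fed at run time, which is what lets ensembling call the decoder repeatedly with max_steps forced to one. A minimal TF 1.x graph-mode sketch follows; the literal 20 is just a stand-in for whatever self.max_steps_int holds.

import tensorflow as tf  # TF 1.x graph-mode API, as used by Neural Monkey

max_steps_int = 20  # stand-in for the decoder's configured maximum

# Yields the default unless the caller feeds a replacement value.
max_steps = tf.placeholder_with_default(max_steps_int, [])

with tf.Session() as sess:
    print(sess.run(max_steps))                  # 20 (default)
    print(sess.run(max_steps, {max_steps: 1}))  # 1 (overridden per call)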

neuralmonkey/decoders/classifier.py

Lines changed: 28 additions & 20 deletions
@@ -4,13 +4,13 @@
 from typeguard import check_argument_types

 from neuralmonkey.dataset import Dataset
-from neuralmonkey.vocabulary import Vocabulary
+from neuralmonkey.decorators import tensor
 from neuralmonkey.model.feedable import FeedDict
 from neuralmonkey.model.parameterized import InitializerSpecs
 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.model.stateful import Stateful
 from neuralmonkey.nn.mlp import MultilayerPerceptron
-from neuralmonkey.decorators import tensor
+from neuralmonkey.vocabulary import Vocabulary


 class Classifier(ModelPart):
@@ -60,35 +60,42 @@ def __init__(self,
         self.activation_fn = activation_fn
         self.dropout_keep_prob = dropout_keep_prob
         self.max_output_len = 1
-
-        with self.use_scope():
-            self.gt_inputs = [tf.placeholder(tf.int32, [None], "targets")]
-
-            mlp_input = tf.concat([enc.output for enc in self.encoders], 1)
-            self._mlp = MultilayerPerceptron(
-                mlp_input, self.layers,
-                self.dropout_keep_prob, len(self.vocabulary),
-                activation_fn=self.activation_fn, train_mode=self.train_mode)
-
-            tf.summary.scalar(
-                "train_optimization_cost",
-                self.cost, collections=["summary_train"])
     # pylint: enable=too-many-arguments

+    # pylint: disable=no-self-use
     @tensor
-    def loss_with_gt_ins(self) -> tf.Tensor:
-        return tf.reduce_mean(
-            tf.nn.sparse_softmax_cross_entropy_with_logits(
-                logits=self._mlp.logits, labels=self.gt_inputs[0]))
+    def gt_inputs(self) -> tf.Tensor:
+        return tf.placeholder(tf.int32, [None], "targets")
+    # pylint: enable=no-self-use
+
+    @tensor
+    def _mlp(self) -> MultilayerPerceptron:
+        mlp_input = tf.concat([enc.output for enc in self.encoders], 1)
+        return MultilayerPerceptron(
+            mlp_input, self.layers, self.dropout_keep_prob,
+            len(self.vocabulary), activation_fn=self.activation_fn,
+            train_mode=self.train_mode)

     @property
     def loss_with_decoded_ins(self) -> tf.Tensor:
         return self.loss_with_gt_ins

     @property
     def cost(self) -> tf.Tensor:
+        tf.summary.scalar(
+            "train_optimization_cost",
+            self.loss_with_gt_ins, collections=["summary_train"])
+
         return self.loss_with_gt_ins

+    # pylint: disable=no-member
+    # this is for the _mlp attribute (pylint property bug)
+    @tensor
+    def loss_with_gt_ins(self) -> tf.Tensor:
+        return tf.reduce_mean(
+            tf.nn.sparse_softmax_cross_entropy_with_logits(
+                logits=self._mlp.logits, labels=self.gt_inputs))
+
     @tensor
     def decoded_seq(self) -> tf.Tensor:
         return tf.expand_dims(self._mlp.classification, 0)
@@ -100,6 +107,7 @@ def decoded_logits(self) -> tf.Tensor:
     @tensor
     def runtime_logprobs(self) -> tf.Tensor:
         return tf.expand_dims(tf.nn.log_softmax(self._mlp.logits), 0)
+    # pylint: enable=no-member

     @property
     def train_loss(self):
@@ -120,6 +128,6 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
         if sentences is not None:
             label_tensors, _ = self.vocabulary.sentences_to_tensor(
                 list(sentences), self.max_output_len)
-            fd[self.gt_inputs[0]] = label_tensors[0]
+            fd[self.gt_inputs] = label_tensors[0]

         return fd
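
A small sketch of the feed-dict consequence of the classifier change: gt_inputs used to be a one-element list of placeholders and is now the placeholder itself, so it becomes the feed key directly. The snippet below is a standalone TF 1.x illustration, not the actual Classifier class; loss_proxy is an invented stand-in for the model's loss.

import numpy as np
import tensorflow as tf

gt_inputs = tf.placeholder(tf.int32, [None], "targets")  # was: [placeholder]
loss_proxy = tf.reduce_sum(gt_inputs)  # stand-in for loss_with_gt_ins

fd = {gt_inputs: np.array([3, 1, 2])}  # was: fd[self.gt_inputs[0]] = ...
with tf.Session() as sess:
    print(sess.run(loss_proxy, fd))  # 6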

neuralmonkey/decoders/sequence_labeler.py

Lines changed: 17 additions & 11 deletions
@@ -4,14 +4,14 @@
 from typeguard import check_argument_types

 from neuralmonkey.dataset import Dataset
+from neuralmonkey.decorators import tensor
+from neuralmonkey.encoders.recurrent import RecurrentEncoder
+from neuralmonkey.encoders.facebook_conv import SentenceEncoder
 from neuralmonkey.model.feedable import FeedDict
 from neuralmonkey.model.parameterized import InitializerSpecs
 from neuralmonkey.model.model_part import ModelPart
-from neuralmonkey.encoders.recurrent import RecurrentEncoder
-from neuralmonkey.encoders.facebook_conv import SentenceEncoder
-from neuralmonkey.vocabulary import Vocabulary
-from neuralmonkey.decorators import tensor
 from neuralmonkey.tf_utils import get_variable
+from neuralmonkey.vocabulary import Vocabulary


 class SequenceLabeler(ModelPart):
@@ -36,15 +36,21 @@ def __init__(self,
         self.vocabulary = vocabulary
         self.data_id = data_id
         self.dropout_keep_prob = dropout_keep_prob
+    # pylint: enable=too-many-arguments

-        self.rnn_size = int(self.encoder.temporal_states.get_shape()[-1])
+    # pylint: disable=no-self-use
+    @tensor
+    def train_targets(self) -> tf.Tensor:
+        return tf.placeholder(tf.int32, [None, None], "targets")

-        with self.use_scope():
-            self.train_targets = tf.placeholder(
-                tf.int32, [None, None], "labeler_targets")
-            self.train_weights = tf.placeholder(
-                tf.float32, [None, None], "labeler_padding_weights")
-        # pylint: enable=too-many-arguments
+    @tensor
+    def train_weights(self) -> tf.Tensor:
+        return tf.placeholder(tf.float32, [None, None], "padding")
+    # pylint: enable=no-self-use
+
+    @property
+    def rnn_size(self) -> int:
+        return int(self.encoder.temporal_states.get_shape()[-1])

     @tensor
     def decoding_w(self) -> tf.Variable: