Commit ef855bd

Merge pull request #349 from ufal/tf1.0
Introducing TensorFlow 1.0 branch
2 parents e5ab2ea + 0850cd3 commit ef855bd

50 files changed: +324 additions, -433 deletions

.readthedocs-conda.yml

Lines changed: 1 addition & 1 deletion

@@ -12,5 +12,5 @@ dependencies:
   - numpy
   - pillow
   - git+https://github.com/aflc/pyter@857a1552443f139a3
-  - https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.11.0-cp35-cp35m-linux_x86_64.whl
+  - tensorflow
   - sphinx==1.5.1

.travis.yml

Lines changed: 2 additions & 3 deletions

@@ -4,8 +4,6 @@ dist: trusty
 language: python

 env:
-  global:
-  - TF=0.11.0-cp35-cp35m
   matrix:
   - TEST_SUITE=lint
   - TEST_SUITE=pycodestyle
@@ -23,8 +21,9 @@ python:

 # commands to install dependencies
 before_install:
+  - sudo apt-get install libtcmalloc-minimal4
+  - export LD_PRELOAD="/usr/lib/libtcmalloc_minimal.so.4"
  - pip install -r requirements.txt
-  - pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-$TF""-linux_x86_64.whl
  - if [ -f tests/$TEST_SUITE""_requirements.txt ]; then pip install -r tests/$TEST_SUITE""_requirements.txt; fi
  - if [ -f tests/$TEST_SUITE""_install.sh ]; then tests/$TEST_SUITE""_install.sh; fi


docs/source/conf.py

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@
     'sphinx.ext.todo',
     'sphinx.ext.coverage',
     'sphinx.ext.pngmath',
+    'sphinx.ext.intersphinx'
 ]

 # Add any paths that contain templates here, relative to this directory.

docs/source/install.rst

Lines changed: 7 additions & 0 deletions

@@ -41,3 +41,10 @@ and CuDNN installations. Similarly, your ``PATH`` variable should point to the
 ``bin`` subdirectory of the CUDA installation directory.

 You made it! Neural Monkey is now installed!
+
+Note for Ubuntu 14.04 users
+***************************
+
+If you get ``Segmentation fault`` errors at the very end of the training
+process, you can either ignore them or follow the steps outlined in `this
+document <ubuntu1404_fix.html>`_.

docs/source/machine_translation.rst

Lines changed: 2 additions & 2 deletions

@@ -238,6 +238,7 @@ The following sections are described in more detail in
 class=tf_manager.TensorFlowManager
 num_threads=4
 num_sessions=1
+minimize_metric=False
 save_n_best=3
 .. TUTCHECK exp-nm-mt/translation.ini

@@ -254,7 +255,6 @@ As for the main configuration section do not forget to add BPE postprocessing:
 train_dataset=<train_data>
 val_dataset=<val_data>
 evaluation=[("series_named_greedy", "target", <bleu>), ("series_named_greedy", "target", evaluators.ter.TER)]
-minimize=False
 batch_size=80
 runners_batch_size=256
 epochs=10
@@ -277,7 +277,7 @@ As for the evaluation, you need to create ``translation_run.ini``:

 [main]
 test_datasets=[<eval_data>]
-
+
 [bpe_preprocess]
 class=processors.bpe.BPEPreprocessor
 merge_file="exp-nm-mt/data/merge_file.bpe"

docs/source/tutorial.rst

Lines changed: 1 addition & 1 deletion

@@ -458,6 +458,7 @@ TensorFlow should use, you need to specify a "TensorFlow manager":
 class=tf_manager.TensorFlowManager
 num_threads=4
 num_sessions=1
+minimize_metric=True
 save_n_best=3
 .. TUTCHECK exp-nm-ape/post-edit.ini

@@ -480,7 +481,6 @@ parameters:
 train_dataset=<train_dataset>
 val_dataset=<val_dataset>
 evaluation=[("greedy_edits", "edits", <bleu>), ("greedy_edits", "edits", evaluators.ter.TER)]
-minimize=True
 batch_size=128
 runners_batch_size=256
 epochs=100

docs/source/ubuntu1404_fix.rst

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+Fixing segmentation fault on exit on Ubuntu 14.04
+=================================================
+
+* On UFAL machines, the segfault can be prevented by doing this:
+
+  .. code-block:: bash
+
+      export LD_PRELOAD=/home/helcl/lib/libtcmalloc_minimal.so.4
+      bin/neuralmonkey-train tests/vocab.ini
+
+* On machines with ``sudo``, one can do this:
+
+  .. code-block:: bash
+
+      sudo apt-get install libtcmalloc-minimal4
+      export LD_PRELOAD="/usr/lib/libtcmalloc_minimal.so.4"
+
+* On machines with neither ``sudo`` nor
+  ``~helcl/lib/libtcmalloc_minimal.so.4``, this is the way to fix the
+  segfaulting:
+
+  .. code-block:: bash
+
+      wget http://archive.ubuntu.com/ubuntu/pool/main/g/google-perftools/google-perftools_2.1.orig.tar.gz
+      tar xpzvf google-perftools_2.1.orig.tar.gz
+      cd gperftools-2.1/
+      ./configure --prefix=$HOME
+      make
+      make install
+
+  If the compilation crashes because it needs the ``libunwind`` library (as
+  it did for me), do this:
+
+  .. code-block:: bash
+
+      wget http://download.savannah.gnu.org/releases/libunwind/libunwind-0.99-beta.tar.gz
+      tar xpzvf libunwind-0.99-beta.tar.gz
+      cd libunwind-0.99-beta/
+      ./configure --prefix=$HOME
+      make
+      make install
+
+  If, by any chance, this compilation crashes with something like ``error:
+  'longjmp' aliased to undefined symbol '_longjmp'``, replace the ``make``
+  call with ``make CFLAGS+=-U_FORTIFY_SOURCE``.
+
+  Then, in the ``$HOME/share`` directory, create a file ``config.site`` like
+  this:
+
+  .. code-block:: bash
+
+      cat << EOF > $HOME/share/config.site
+      CPPFLAGS=-I$HOME/include
+      LDFLAGS=-L$HOME/lib
+      EOF
+
+  and then redo the configure-make-make install mantra from gperftools.
+  Finally, set the ``LD_PRELOAD`` environment variable to point to
+  ``$HOME/lib/libtcmalloc_minimal.so.4``.

neuralmonkey/decoders/decoder.py

Lines changed: 14 additions & 14 deletions

@@ -204,11 +204,11 @@ def decode(rnn_outputs):

         _, self.train_logits = decode(train_rnn_outputs)

-        train_targets = tf.unpack(self.train_inputs)
+        train_targets = tf.transpose(self.train_inputs)

-        self.train_loss = tf.nn.seq2seq.sequence_loss(
-            self.train_logits, train_targets,
-            tf.unpack(self.train_padding), len(self.vocabulary))
+        self.train_loss = tf.contrib.seq2seq.sequence_loss(
+            tf.stack(self.train_logits, 1), train_targets,
+            tf.transpose(self.train_padding))
         self.cost = self.train_loss

         self.train_logprobs = [tf.nn.log_softmax(l)
@@ -217,9 +217,9 @@ def decode(rnn_outputs):
         self.decoded, self.runtime_logits = decode(
             self.runtime_rnn_outputs)

-        self.runtime_loss = tf.nn.seq2seq.sequence_loss(
-            self.runtime_logits, train_targets,
-            tf.unpack(self.train_padding), len(self.vocabulary))
+        self.runtime_loss = tf.contrib.seq2seq.sequence_loss(
+            tf.stack(self.runtime_logits, 1), train_targets,
+            tf.transpose(self.train_padding))

         self.runtime_logprobs = [tf.nn.log_softmax(l)
                                  for l in self.runtime_logits]
@@ -306,11 +306,11 @@ def _logit_function(self, state: tf.Tensor) -> tf.Tensor:
         state = dropout(state, self.dropout_keep_prob, self.train_mode)
         return tf.matmul(state, self.decoding_w) + self.decoding_b

-    def _get_rnn_cell(self) -> tf.nn.rnn_cell.RNNCell:
+    def _get_rnn_cell(self) -> tf.contrib.rnn.RNNCell:
         if self._rnn_cell == 'GRU':
-            return tf.nn.rnn_cell.GRUCell(self.rnn_size)
+            return tf.contrib.rnn.GRUCell(self.rnn_size)
         elif self._rnn_cell == 'LSTM':
-            return tf.nn.rnn_cell.LSTMCell(self.rnn_size)
+            return tf.contrib.rnn.LSTMCell(self.rnn_size)
         else:
             raise ValueError("Unknown RNN cell: {}".format(self._rnn_cell))

@@ -355,7 +355,7 @@ def _attention_decoder(
             state = self.initial_state
         elif self._rnn_cell == 'LSTM':
             # pylint: disable=redefined-variable-type
-            state = tf.nn.rnn_cell.LSTMStateTuple(
+            state = tf.contrib.rnn.LSTMStateTuple(
                 self.initial_state, self.initial_state)
             # pylint: enable=redefined-variable-type
         else:
@@ -423,12 +423,12 @@ def _visualize_attention(self):

         for i, a in enumerate(att_objects):
             alignments = tf.expand_dims(tf.transpose(
-                tf.pack(a.attentions_in_time), perm=[1, 2, 0]), -1)
+                tf.stack(a.attentions_in_time), perm=[1, 2, 0]), -1)

-            tf.image_summary(
+            tf.summary.image(
                 "attention_{}".format(i), alignments,
                 collections=["summary_val_plots"],
-                max_images=256)
+                max_outputs=256)

     def feed_dict(self, dataset: Dataset, train: bool=False) -> FeedDict:
         """Populate the feed dictionary for the decoder object

neuralmonkey/decoders/encoder_projection.py

Lines changed: 2 additions & 2 deletions

@@ -62,7 +62,7 @@ def func(train_mode: tf.Tensor,
                 " of encoder projection")

         with tf.variable_scope("encoders_projection") as scope:
-            encoded_concat = tf.concat(1, [e.encoded for e in encoders])
+            encoded_concat = tf.concat([e.encoded for e in encoders], 1)
             encoded_concat = dropout(
                 encoded_concat, dropout_keep_prob, train_mode)

@@ -90,7 +90,7 @@ def concat_encoder_projection(
     assert rnn_size == sum(e.encoded.get_shape()[1].value
                            for e in encoders)

-    encoded_concat = tf.concat(1, [e.encoded for e in encoders])
+    encoded_concat = tf.concat([e.encoded for e in encoders], 1)

     # pylint: disable=no-member
     log("The inferred rnn_size of this encoder projection will be {}"

neuralmonkey/decoders/multi_decoder.py

Lines changed: 2 additions & 2 deletions

@@ -49,8 +49,8 @@ def __init__(self, main_decoder, regularization_decoders):
         self.regularization_decoders = regularization_decoders

         self._training_decoders = [main_decoder] + regularization_decoders
-        self._decoder_costs = tf.concat(0, [tf.expand_dims(d.cost, 0)
-                                            for d in self._training_decoders])
+        self._decoder_costs = tf.concat([tf.expand_dims(d.cost, 0)
+                                         for d in self._training_decoders], 0)

         self._scheduled_decoder = 0
         self._input_selector = tf.placeholder(tf.float32,

neuralmonkey/decoders/output_projection.py

Lines changed: 2 additions & 2 deletions

@@ -24,7 +24,7 @@ def no_deep_output(prev_state, prev_output, ctx_tensors):
     Returns:
         This function returns the concatenation of all its inputs.
     """
-    return tf.concat(1, [prev_state, prev_output] + ctx_tensors)
+    return tf.concat([prev_state, prev_output] + ctx_tensors, 1)


 def maxout_output(maxout_size):
@@ -63,7 +63,7 @@ def mlp_output(layer_sizes, dropout_plc=None, activation=tf.tanh):
         activation: The activation function to use in each layer.
     """
     def _projection(prev_state, prev_output, ctx_tensors):
-        mlp_input = tf.concat(1, [prev_state, prev_output] + ctx_tensors)
+        mlp_input = tf.concat([prev_state, prev_output] + ctx_tensors, 1)

         return multilayer_projection(mlp_input, layer_sizes,
                                      activation=activation,

neuralmonkey/decoders/sequence_classifier.py

Lines changed: 4 additions & 4 deletions

@@ -62,25 +62,25 @@ def __init__(self,
             tf.placeholder(tf.float32, name="dropout_plc")
         self.gt_inputs = [tf.placeholder(
             tf.int32, shape=[None], name="targets")]
-        mlp_input = tf.concat(1, [enc.encoded for enc in encoders])
+        mlp_input = tf.concat([enc.encoded for enc in encoders], 1)
         mlp = MultilayerPerceptron(
             mlp_input, layers, self.dropout_placeholder, len(vocabulary),
             activation_fn=self.activation_fn)

         self.loss_with_gt_ins = tf.reduce_mean(
             tf.nn.sparse_softmax_cross_entropy_with_logits(
-                mlp.logits, self.gt_inputs[0]))
+                logits=mlp.logits, labels=self.gt_inputs[0]))
         self.loss_with_decoded_ins = self.loss_with_gt_ins
         self.cost = self.loss_with_gt_ins

         self.decoded_seq = [mlp.classification]
         self.decoded_logits = [mlp.logits]
         self.runtime_logprobs = [tf.nn.log_softmax(mlp.logits)]

-        tf.scalar_summary(
+        tf.summary.scalar(
             'val_optimization_cost', self.cost,
             collections=["summary_val"])
-        tf.scalar_summary(
+        tf.summary.scalar(
             'train_optimization_cost',
             self.cost, collections=["summary_train"])
         # pylint: enable=too-many-arguments
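
Two further recurring patterns appear here. The cross-entropy helpers demand keyword arguments in TensorFlow 1.0 (the old positional order made it easy to swap logits and labels silently), and the summary ops moved into the ``tf.summary`` namespace. A minimal sketch with made-up shapes:

    import tensorflow as tf

    logits = tf.random_normal([8, 10])      # batch of 8, 10 classes
    labels = tf.zeros([8], dtype=tf.int32)  # gold class ids

    # TF 1.0 rejects positional logits/labels here; keywords are required.
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)
    cost = tf.reduce_mean(xent)

    # tf.scalar_summary and tf.image_summary became
    # tf.summary.scalar and tf.summary.image.
    tf.summary.scalar("cost", cost)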

neuralmonkey/decoders/sequence_labeler.py

Lines changed: 1 addition & 1 deletion

@@ -86,7 +86,7 @@ def logits(self) -> tf.Tensor:
         biases = tf.get_variable(
             name="state_to_word_b",
             shape=[vocabulary_size],
-            initializer=tf.zeros_initializer)
+            initializer=tf.zeros_initializer())

         weights_direct = tf.get_variable(
             name="emb_to_word_W",

neuralmonkey/decoders/word_alignment_decoder.py

Lines changed: 4 additions & 4 deletions

@@ -37,15 +37,15 @@ def __init__(self,
         _, self.train_loss = self._make_decoder(runtime_mode=False)
         self.decoded, self.runtime_loss = self._make_decoder(runtime_mode=True)

-        tf.scalar_summary("alignment_train_xent", self.train_loss,
+        tf.summary.scalar("alignment_train_xent", self.train_loss,
                           collections=["summary_train"])

     def _make_decoder(self, runtime_mode=False):
         attn_obj = self.decoder.get_attention_object(self.encoder,
                                                      runtime_mode)

-        alignment_logits = tf.pack(attn_obj.logits_in_time,
-                                   name="alignment_logits")
+        alignment_logits = tf.stack(attn_obj.logits_in_time,
+                                    name="alignment_logits")

         if runtime_mode:
             # make batch_size the first dimension
@@ -56,7 +56,7 @@ def _make_decoder(self, runtime_mode=False):
             alignment = None

         xent = tf.nn.softmax_cross_entropy_with_logits(
-            alignment_logits, self.alignment_target)
+            labels=self.alignment_target, logits=alignment_logits)
         loss = tf.reduce_sum(xent * self.decoder.train_padding)

         return alignment, loss
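
Besides the keyword-argument change shown for the classifier above, this file picks up the wholesale rename of ``tf.pack``/``tf.unpack`` to ``tf.stack``/``tf.unstack``; the semantics are unchanged. A minimal sketch with a made-up list of per-step tensors:

    import tensorflow as tf

    steps = [tf.ones([4, 7]) for _ in range(20)]  # 20 per-step tensors

    stacked = tf.stack(steps)        # was tf.pack; shape (20, 4, 7)
    unstacked = tf.unstack(stacked)  # was tf.unpack; a list of 20 tensors again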

neuralmonkey/encoders/factored_encoder.py

Lines changed: 10 additions & 11 deletions

@@ -7,7 +7,6 @@
 from neuralmonkey.model.model_part import ModelPart
 from neuralmonkey.encoders.attentive import Attentive
 from neuralmonkey.logging import log
-from neuralmonkey.nn.bidirectional_rnn_layer import BidirectionalRNNLayer
 from neuralmonkey.vocabulary import Vocabulary


@@ -72,7 +71,7 @@ def _attention_tensor(self):

     def _get_rnn_cell(self):
         """Return the RNN cell for the encoder"""
-        return tf.nn.rnn_cell.GRUCell(self.rnn_size)
+        return tf.contrib.rnn.GRUCell(self.rnn_size)

     def _get_birnn_cells(self):
         """Return forward and backward RNN cells for the encoder"""
@@ -138,25 +137,25 @@ def _create_encoder_graph(self):
         # factors is a 2D list of embeddings of dims [factor-type, time-step]
         # by doing zip(*factors), we get a list of (factor-type) embedding
         # tuples indexed by the time step
-        concatenated_factors = [tf.concat(1, related_factors)
+        concatenated_factors = [tf.concat(related_factors, 1)
                                 for related_factors in zip(*factors)]
         assert_shape(concatenated_factors[0],
                      [None, sum(self.embedding_sizes)])
         forward_gru, backward_gru = self._get_birnn_cells()

-        bidi_layer = BidirectionalRNNLayer(forward_gru, backward_gru,
-                                           concatenated_factors,
-                                           sentence_lengths)
+        stacked_factors = tf.stack(concatenated_factors, 1)

-        self.outputs_bidi = bidi_layer.outputs_bidi
-        self.encoded = bidi_layer.encoded
+        self.outputs_bidi, encoded_tup = tf.nn.bidirectional_dynamic_rnn(
+            forward_gru, backward_gru, stacked_factors,
+            sentence_lengths, dtype=tf.float32)

-        self.__attention_tensor = tf.concat(1, [tf.expand_dims(o, 1)
-                                                for o in self.outputs_bidi])
+        self.encoded = tf.concat(encoded_tup, 1)
+
+        self.__attention_tensor = tf.concat(self.outputs_bidi, 2)
         self.__attention_tensor = tf.nn.dropout(self.__attention_tensor,
                                                 self.dropout_placeholder)
         self.__attention_mask = tf.concat(
-            1, [tf.expand_dims(w, 1) for w in self.padding_weights])
+            [tf.expand_dims(w, 1) for w in self.padding_weights], 1)

     # pylint: disable=too-many-locals
     def feed_dict(self, dataset, train=False):
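
The encoder rewrite above drops the project's custom ``BidirectionalRNNLayer`` in favour of the built-in ``tf.nn.bidirectional_dynamic_rnn``, which takes one batch-major 3-D input tensor instead of a list of per-step tensors. A minimal sketch with hypothetical sizes (names are illustrative, not from the commit):

    import tensorflow as tf

    batch, max_time, dim = 16, 30, 100
    inputs = tf.random_normal([batch, max_time, dim])  # batch-major embeddings
    lengths = tf.fill([batch], max_time)               # true sentence lengths

    fw_cell = tf.contrib.rnn.GRUCell(256)
    bw_cell = tf.contrib.rnn.GRUCell(256)

    # Returns ((fw_outputs, bw_outputs), (fw_state, bw_state)).
    outputs, states = tf.nn.bidirectional_dynamic_rnn(
        fw_cell, bw_cell, inputs, sequence_length=lengths, dtype=tf.float32)

    # Concatenating the two output tensors on the last axis gives the
    # attention tensor; concatenating the final states gives the
    # fixed-size sentence encoding, as in the diff above.
    attention_tensor = tf.concat(outputs, 2)  # (batch, max_time, 512)
    encoded = tf.concat(states, 1)            # (batch, 512)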
