Skip to content

Commit b384686

Browse files
authored
Merge pull request #781 from ufal/tf-data-2
Towards TF dataset, part III
2 parents 2c71059 + e3f0f68 commit b384686

File tree

94 files changed

+828
-1223
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

94 files changed

+828
-1223
lines changed

.travis.yml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@ env:
1414
- TEST_SUITE=mypy
1515

1616
python:
17-
#- "2.7"
18-
#- "3.4"
19-
- "3.5"
20-
#- "3.5-dev" # 3.5 development branch
21-
#- "nightly" # currently points to 3.6-dev
17+
- "3.6"
2218

2319
# commands to install dependencies
2420
before_install:

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ python_speech_features
1010
pygments
1111
typeguard
1212
sacrebleu
13-
tensorflow>=1.10.0,<1.11
13+
tensorflow>=1.12.0,<1.13

neuralmonkey/checking.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,64 +4,14 @@
44
constructing the computational graph.
55
"""
66

7-
8-
from typing import List, Optional, Iterable
9-
7+
from typing import List, Optional
108
import tensorflow as tf
119

12-
from neuralmonkey.logging import log, debug
13-
from neuralmonkey.dataset import Dataset
14-
from neuralmonkey.runners.base_runner import BaseRunner
15-
1610

1711
class CheckingException(Exception):
1812
pass
1913

2014

21-
def check_dataset_and_coders(dataset: Dataset,
22-
runners: Iterable[BaseRunner]) -> None:
23-
# pylint: disable=protected-access
24-
25-
data_list = []
26-
for runner in runners:
27-
for c in runner.feedables:
28-
if hasattr(c, "data_id"):
29-
data_list.append((getattr(c, "data_id"), c))
30-
elif hasattr(c, "data_ids"):
31-
data_list.extend([(d, c) for d in getattr(c, "data_ids")])
32-
elif hasattr(c, "input_sequence"):
33-
inpseq = getattr(c, "input_sequence")
34-
if hasattr(inpseq, "data_id"):
35-
data_list.append((getattr(inpseq, "data_id"), c))
36-
elif hasattr(inpseq, "data_ids"):
37-
data_list.extend(
38-
[(d, c) for d in getattr(inpseq, "data_ids")])
39-
else:
40-
log("Input sequence: {} does not have a data attribute"
41-
.format(str(inpseq)))
42-
else:
43-
log(("Coder: {} has neither an input sequence attribute nor a "
44-
"a data attribute.").format(c))
45-
46-
debug("Found series: {}".format(str(data_list)), "checking")
47-
missing = []
48-
49-
for (serie, coder) in data_list:
50-
if serie not in dataset:
51-
log("dataset {} does not have serie {}".format(
52-
dataset.name, serie))
53-
missing.append((coder, serie))
54-
55-
if missing:
56-
formated = ["{} ({}, {}.{})" .format(serie, str(cod),
57-
cod.__class__.__module__,
58-
cod.__class__.__name__)
59-
for cod, serie in missing]
60-
61-
raise CheckingException("Dataset '{}' is mising series {}:"
62-
.format(dataset.name, ", ".join(formated)))
63-
64-
6515
def assert_shape(tensor: tf.Tensor,
6616
expected_shape: List[Optional[int]]) -> None:
6717
"""Check shape of a tensor.

neuralmonkey/checkpython.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22

3-
if sys.version_info[0] < 3 or sys.version_info[1] < 5:
3+
if sys.version_info[0] < 3 or sys.version_info[1] < 6:
44
print("Error:", file=sys.stderr)
5-
print("Neural Monkey must use Python >= 3.5", file=sys.stderr)
5+
print("Neural Monkey must use Python >= 3.6", file=sys.stderr)
66
print("Your Python is", sys.version, sys.executable, file=sys.stderr)
77
sys.exit(1)

neuralmonkey/config/normalize.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import numpy as np
1616

17-
from neuralmonkey.dataset import BatchingScheme
1817
from neuralmonkey.logging import warn
1918
from neuralmonkey.tf_manager import get_default_tf_manager
2019
from neuralmonkey.trainers.delayed_update_trainer import DelayedUpdateTrainer
@@ -34,25 +33,6 @@ def normalize_configuration(cfg: Namespace, train_mode: bool) -> None:
3433
if cfg.tf_manager is None:
3534
cfg.tf_manager = get_default_tf_manager()
3635

37-
if (cfg.batch_size is None) == (cfg.batching_scheme is None):
38-
raise ValueError("You must specify either batch_size or "
39-
"batching_scheme (not both).")
40-
41-
if cfg.batch_size is not None:
42-
assert cfg.batching_scheme is None
43-
cfg.batching_scheme = BatchingScheme(batch_size=cfg.batch_size)
44-
else:
45-
assert cfg.batching_scheme is not None
46-
cfg.batch_size = cfg.batching_scheme.batch_size
47-
48-
if cfg.runners_batch_size is None:
49-
cfg.runners_batch_size = cfg.batching_scheme.batch_size
50-
51-
cfg.runners_batching_scheme = BatchingScheme(
52-
batch_size=cfg.runners_batch_size,
53-
token_level_batching=cfg.batching_scheme.token_level_batching,
54-
use_leftover_buckets=True)
55-
5636
cfg.evaluation = [(e[0], e[0], e[1]) if len(e) == 2 else e
5737
for e in cfg.evaluation]
5838

neuralmonkey/config/parsing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,13 @@ def _parse_class_name(string: str, vars_dict: VarsDict) -> ClassSymbol:
150150

151151

152152
def _parse_value(string: str, vars_dict: VarsDict) -> Any:
153-
"""Parse the value recursively according to the Nerualmonkey grammar.
153+
"""Parse the value recursively according to the Neural Monkey grammar.
154154
155155
Arguments:
156156
string: the string to be parsed
157157
vars_dict: a dictionary of variables for substitution
158158
"""
159+
string = string.strip()
159160

160161
if string in CONSTANTS:
161162
return CONSTANTS[string]

0 commit comments

Comments
 (0)