Skip to content

Commit b57d4d3

Browse files
authored
Merge pull request #773 from ufal/model_refactor
Model refactor
2 parents be667fd + 89900f8 commit b57d4d3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+529
-365
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ max-branches=12
345345
max-statements=50
346346

347347
# Maximum number of parents for a class (see R0901).
348-
max-parents=7
348+
max-parents=10
349349

350350
# Maximum number of attributes for a class (see R0902).
351351
max-attributes=12

neuralmonkey/attention/base_attention.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@
4141

4242
import tensorflow as tf
4343

44-
from neuralmonkey.model.stateful import TemporalStateful, SpatialStateful
45-
from neuralmonkey.model.model_part import ModelPart, InitializerSpecs
4644
from neuralmonkey.attention.namedtuples import AttentionLoopState
45+
from neuralmonkey.model.model_part import ModelPart
46+
from neuralmonkey.model.parameterized import InitializerSpecs
47+
from neuralmonkey.model.stateful import TemporalStateful, SpatialStateful
4748

4849
# pylint: disable=invalid-name
4950
Attendable = Union[TemporalStateful, SpatialStateful]

neuralmonkey/attention/combination.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
get_attention_states, get_attention_mask, Attendable)
2424
from neuralmonkey.attention.namedtuples import HierarchicalLoopState
2525
from neuralmonkey.checking import assert_shape
26-
from neuralmonkey.model.model_part import InitializerSpecs, ModelPart
26+
from neuralmonkey.model.model_part import ModelPart
27+
from neuralmonkey.model.parameterized import InitializerSpecs
2728
from neuralmonkey.tf_utils import get_variable
2829

2930

neuralmonkey/attention/coverage.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
from neuralmonkey.attention.base_attention import Attendable
1212
from neuralmonkey.attention.feed_forward import Attention
13-
from neuralmonkey.model.model_part import InitializerSpecs, ModelPart
13+
from neuralmonkey.model.model_part import ModelPart
14+
from neuralmonkey.model.parameterized import InitializerSpecs
1415

1516

1617
class CoverageAttention(Attention):

neuralmonkey/attention/feed_forward.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
BaseAttention, AttentionLoopState, empty_attention_loop_state,
1414
get_attention_states, get_attention_mask, Attendable)
1515
from neuralmonkey.decorators import tensor
16-
from neuralmonkey.nn.utils import dropout
1716
from neuralmonkey.logging import log
18-
from neuralmonkey.model.model_part import InitializerSpecs, ModelPart
17+
from neuralmonkey.model.model_part import ModelPart
18+
from neuralmonkey.model.parameterized import InitializerSpecs
19+
from neuralmonkey.nn.utils import dropout
1920
from neuralmonkey.tf_utils import get_variable
2021

2122

neuralmonkey/attention/scaled_dot_product.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from typeguard import check_argument_types
1414

1515
from neuralmonkey.nn.utils import dropout
16-
from neuralmonkey.model.model_part import InitializerSpecs, ModelPart
16+
from neuralmonkey.model.model_part import ModelPart
17+
from neuralmonkey.model.parameterized import InitializerSpecs
1718
from neuralmonkey.attention.base_attention import (
1819
BaseAttention, Attendable, get_attention_states, get_attention_mask)
1920
from neuralmonkey.attention.namedtuples import MultiHeadLoopState

neuralmonkey/attention/stateful_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
BaseAttention, AttentionLoopState, empty_attention_loop_state)
88
from neuralmonkey.model.stateful import Stateful
99
from neuralmonkey.decorators import tensor
10-
from neuralmonkey.model.model_part import InitializerSpecs, ModelPart
10+
from neuralmonkey.model.model_part import ModelPart
11+
from neuralmonkey.model.parameterized import InitializerSpecs
1112

1213

1314
class StatefulContext(BaseAttention):

neuralmonkey/checking.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def check_dataset_and_coders(dataset: Dataset,
2424

2525
data_list = []
2626
for runner in runners:
27-
for c in runner.all_coders:
27+
for c in runner.feedables:
2828
if hasattr(c, "data_id"):
2929
data_list.append((getattr(c, "data_id"), c))
3030
elif hasattr(c, "data_ids"):
@@ -53,7 +53,7 @@ def check_dataset_and_coders(dataset: Dataset,
5353
missing.append((coder, serie))
5454

5555
if missing:
56-
formated = ["{} ({}, {}.{})" .format(serie, cod.name,
56+
formated = ["{} ({}, {}.{})" .format(serie, str(cod),
5757
cod.__class__.__module__,
5858
cod.__class__.__name__)
5959
for cod, serie in missing]

neuralmonkey/decoders/autoregressive.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
from neuralmonkey.dataset import Dataset
1414
from neuralmonkey.decorators import tensor
15-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
15+
from neuralmonkey.model.feedable import FeedDict
16+
from neuralmonkey.model.parameterized import InitializerSpecs
17+
from neuralmonkey.model.model_part import ModelPart
1618
from neuralmonkey.logging import log, warn
1719
from neuralmonkey.model.sequence import EmbeddedSequence
1820
from neuralmonkey.nn.utils import dropout

neuralmonkey/decoders/classifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
from neuralmonkey.dataset import Dataset
77
from neuralmonkey.vocabulary import Vocabulary
8-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
8+
from neuralmonkey.model.feedable import FeedDict
9+
from neuralmonkey.model.parameterized import InitializerSpecs
10+
from neuralmonkey.model.model_part import ModelPart
911
from neuralmonkey.model.stateful import Stateful
1012
from neuralmonkey.nn.mlp import MultilayerPerceptron
1113
from neuralmonkey.decorators import tensor

neuralmonkey/decoders/ctc_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from neuralmonkey.dataset import Dataset
88
from neuralmonkey.decorators import tensor
99
from neuralmonkey.logging import log
10-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
10+
from neuralmonkey.model.feedable import FeedDict
11+
from neuralmonkey.model.parameterized import InitializerSpecs
12+
from neuralmonkey.model.model_part import ModelPart
1113
from neuralmonkey.model.stateful import TemporalStateful
1214
from neuralmonkey.tf_utils import get_variable
1315
from neuralmonkey.vocabulary import Vocabulary, END_TOKEN

neuralmonkey/decoders/decoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
Vocabulary, END_TOKEN_INDEX, PAD_TOKEN_INDEX)
1111
from neuralmonkey.model.sequence import EmbeddedSequence
1212
from neuralmonkey.model.stateful import Stateful
13-
from neuralmonkey.model.model_part import ModelPart, InitializerSpecs
13+
from neuralmonkey.model.parameterized import InitializerSpecs
14+
from neuralmonkey.model.model_part import ModelPart
1415
from neuralmonkey.logging import log
1516
from neuralmonkey.nn.ortho_gru_cell import OrthoGRUCell, NematusGRUCell
1617
from neuralmonkey.nn.utils import dropout

neuralmonkey/decoders/sequence_labeler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from typeguard import check_argument_types
55

66
from neuralmonkey.dataset import Dataset
7-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
7+
from neuralmonkey.model.feedable import FeedDict
8+
from neuralmonkey.model.parameterized import InitializerSpecs
9+
from neuralmonkey.model.model_part import ModelPart
810
from neuralmonkey.encoders.recurrent import RecurrentEncoder
911
from neuralmonkey.encoders.facebook_conv import SentenceEncoder
1012
from neuralmonkey.vocabulary import Vocabulary

neuralmonkey/decoders/sequence_regressor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
from neuralmonkey.nn.projection import multilayer_projection
77
from neuralmonkey.dataset import Dataset
8-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
8+
from neuralmonkey.model.feedable import FeedDict
9+
from neuralmonkey.model.parameterized import InitializerSpecs
10+
from neuralmonkey.model.model_part import ModelPart
911
from neuralmonkey.model.stateful import Stateful
1012
from neuralmonkey.decorators import tensor
1113

neuralmonkey/decoders/transformer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
AutoregressiveDecoder, LoopState, DecoderFeedables)
2121
from neuralmonkey.encoders.transformer import (
2222
TransformerLayer, position_signal)
23-
from neuralmonkey.model.sequence import EmbeddedSequence
2423
from neuralmonkey.logging import log, warn
25-
from neuralmonkey.model.model_part import ModelPart, InitializerSpecs
24+
from neuralmonkey.model.sequence import EmbeddedSequence
25+
from neuralmonkey.model.parameterized import InitializerSpecs
26+
from neuralmonkey.model.model_part import ModelPart
2627
from neuralmonkey.nn.utils import dropout
2728
from neuralmonkey.vocabulary import (
2829
Vocabulary, PAD_TOKEN_INDEX, END_TOKEN_INDEX)

neuralmonkey/decoders/word_alignment_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from neuralmonkey.encoders.recurrent import RecurrentEncoder
99
from neuralmonkey.decoders.decoder import Decoder
1010
from neuralmonkey.logging import warn
11-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
11+
from neuralmonkey.model.feedable import FeedDict
12+
from neuralmonkey.model.parameterized import InitializerSpecs
13+
from neuralmonkey.model.model_part import ModelPart
1214
from neuralmonkey.model.sequence import Sequence
1315
from neuralmonkey.decorators import tensor
1416

neuralmonkey/encoders/attentive.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from typing import Set, cast
2-
31
import tensorflow as tf
42
from typeguard import check_argument_types
53

64
from neuralmonkey.model.stateful import TemporalStatefulWithOutput
7-
from neuralmonkey.model.model_part import ModelPart, InitializerSpecs
5+
from neuralmonkey.model.parameterized import InitializerSpecs
6+
from neuralmonkey.model.model_part import ModelPart
87
from neuralmonkey.nn.utils import dropout
98
from neuralmonkey.decorators import tensor
109
from neuralmonkey.attention.base_attention import (
@@ -102,13 +101,3 @@ def output(self) -> tf.Tensor:
102101
name="output_projection")
103102

104103
return output
105-
106-
def get_dependencies(self) -> Set[ModelPart]:
107-
deps = ModelPart.get_dependencies(self)
108-
109-
# feed only if needed
110-
if isinstance(self.input_sequence, ModelPart):
111-
feedable = cast(ModelPart, self.input_sequence)
112-
deps |= feedable.get_dependencies()
113-
114-
return deps

neuralmonkey/encoders/cnn_encoder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
"""CNN for image processing."""
22

3-
from typing import cast, Callable, List, Tuple, Set, Union
3+
from typing import cast, Callable, List, Tuple, Union
44
from typeguard import check_argument_types
55

66
import numpy as np
77
import tensorflow as tf
88

99
from neuralmonkey.dataset import Dataset
1010
from neuralmonkey.decorators import tensor
11-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
11+
from neuralmonkey.model.feedable import FeedDict
12+
from neuralmonkey.model.parameterized import InitializerSpecs
13+
from neuralmonkey.model.model_part import ModelPart
1214
from neuralmonkey.model.stateful import (SpatialStatefulWithOutput,
1315
TemporalStatefulWithOutput)
1416
from neuralmonkey.nn.projection import multilayer_projection
@@ -349,6 +351,6 @@ def temporal_mask(self) -> tf.Tensor:
349351
summed = tf.reduce_sum(mask, axis=1)
350352
return tf.to_float(tf.greater(summed, 0))
351353

352-
def get_dependencies(self) -> Set["ModelPart"]:
353-
"""Collect recusively all encoders and decoders."""
354-
return self._cnn.get_dependencies().union([self])
354+
@property
355+
def dependencies(self) -> List[str]:
356+
return super().dependencies + ["_cnn"]

neuralmonkey/encoders/facebook_conv.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import numpy as np
88
from typeguard import check_argument_types
99

10-
from neuralmonkey.model.model_part import ModelPart, InitializerSpecs
11-
from neuralmonkey.logging import log
1210
from neuralmonkey.decorators import tensor
13-
from neuralmonkey.nn.projection import glu
11+
from neuralmonkey.logging import log
12+
from neuralmonkey.model.model_part import ModelPart
13+
from neuralmonkey.model.parameterized import InitializerSpecs
1414
from neuralmonkey.model.sequence import EmbeddedSequence
1515
from neuralmonkey.model.stateful import TemporalStatefulWithOutput
16+
from neuralmonkey.nn.projection import glu
1617
from neuralmonkey.tf_utils import get_variable
1718

1819

neuralmonkey/encoders/imagenet_encoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
from neuralmonkey.dataset import Dataset
1717
from neuralmonkey.decorators import tensor
18-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
18+
from neuralmonkey.model.feedable import FeedDict
19+
from neuralmonkey.model.model_part import ModelPart
20+
from neuralmonkey.model.parameterized import InitializerSpecs
1921
from neuralmonkey.model.stateful import SpatialStatefulWithOutput
2022

2123

neuralmonkey/encoders/numpy_stateful_filler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
from neuralmonkey.dataset import Dataset
77
from neuralmonkey.decorators import tensor
8-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
8+
from neuralmonkey.model.feedable import FeedDict
9+
from neuralmonkey.model.parameterized import InitializerSpecs
10+
from neuralmonkey.model.model_part import ModelPart
911
from neuralmonkey.model.stateful import Stateful, SpatialStatefulWithOutput
1012

1113

neuralmonkey/encoders/pooling.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from typing import Set, cast
2-
31
import tensorflow as tf
42
from typeguard import check_argument_types
53

64
from neuralmonkey.model.stateful import Stateful, TemporalStateful
7-
from neuralmonkey.model.model_part import ModelPart, InitializerSpecs
5+
from neuralmonkey.model.parameterized import InitializerSpecs
6+
from neuralmonkey.model.model_part import ModelPart
87
from neuralmonkey.decorators import tensor
98

109

1110
# pylint: disable=abstract-method
11+
# Pylint bug: https://github.com/PyCQA/pylint/issues/179
1212
class SequencePooling(ModelPart, Stateful):
1313
"""An abstract pooling layer over a sequence."""
1414

@@ -31,16 +31,6 @@ def __init__(self,
3131
self.input_sequence.temporal_mask, -1)
3232
self._masked_input = (
3333
self.input_sequence.temporal_states * self._input_mask)
34-
35-
def get_dependencies(self) -> Set[ModelPart]:
36-
deps = ModelPart.get_dependencies(self)
37-
38-
# feed only if needed
39-
if isinstance(self.input_sequence, ModelPart):
40-
feedable = cast(ModelPart, self.input_sequence)
41-
deps |= feedable.get_dependencies()
42-
43-
return deps
4434
# pylint: enable=abstract-method
4535

4636

neuralmonkey/encoders/raw_rnn_encoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from neuralmonkey.encoders.recurrent import (
99
RNNSpecTuple, _make_rnn_spec, _make_rnn_cell)
1010
# pylint: enable=protected-access
11-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
11+
from neuralmonkey.model.feedable import FeedDict
12+
from neuralmonkey.model.parameterized import InitializerSpecs
13+
from neuralmonkey.model.model_part import ModelPart
1214
from neuralmonkey.model.stateful import TemporalStatefulWithOutput
1315
from neuralmonkey.logging import log
1416
from neuralmonkey.nn.utils import dropout

neuralmonkey/encoders/recurrent.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Tuple, List, Union, Callable, cast, Set, NamedTuple
1+
from typing import Tuple, List, Union, Callable, NamedTuple
22

33
import tensorflow as tf
44
from typeguard import check_argument_types
55

66
from neuralmonkey.model.stateful import (
77
TemporalStatefulWithOutput, TemporalStateful)
8-
from neuralmonkey.model.model_part import ModelPart, InitializerSpecs
8+
from neuralmonkey.model.parameterized import InitializerSpecs
9+
from neuralmonkey.model.model_part import ModelPart
910
from neuralmonkey.logging import warn
1011
from neuralmonkey.nn.ortho_gru_cell import OrthoGRUCell, NematusGRUCell
1112
from neuralmonkey.nn.utils import dropout
@@ -194,16 +195,6 @@ def output(self) -> tf.Tensor:
194195
return self.rnn[1]
195196
# pylint: enable=unsubscriptable-object
196197

197-
def get_dependencies(self) -> Set[ModelPart]:
198-
"""Collect recusively all encoders and decoders."""
199-
deps = ModelPart.get_dependencies(self)
200-
201-
# feed only if needed
202-
if isinstance(self.input_sequence, ModelPart):
203-
feedable = cast(ModelPart, self.input_sequence)
204-
deps = deps.union(feedable.get_dependencies())
205-
return deps
206-
207198

208199
class SentenceEncoder(RecurrentEncoder):
209200
# pylint: disable=too-many-arguments,too-many-locals

neuralmonkey/encoders/sentence_cnn_encoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from typeguard import check_argument_types
77

88
from neuralmonkey.encoders.recurrent import RNNCellTuple
9-
from neuralmonkey.model.model_part import ModelPart, InitializerSpecs
9+
from neuralmonkey.model.parameterized import InitializerSpecs
10+
from neuralmonkey.model.model_part import ModelPart
1011
from neuralmonkey.model.sequence import Sequence
1112
from neuralmonkey.model.stateful import TemporalStatefulWithOutput
1213
from neuralmonkey.nn.noisy_gru_cell import NoisyGRUCell

neuralmonkey/encoders/sequence_cnn_encoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
from neuralmonkey.dataset import Dataset
99
from neuralmonkey.decorators import tensor
10-
from neuralmonkey.model.model_part import ModelPart, FeedDict, InitializerSpecs
10+
from neuralmonkey.model.feedable import FeedDict
11+
from neuralmonkey.model.parameterized import InitializerSpecs
12+
from neuralmonkey.model.model_part import ModelPart
1113
from neuralmonkey.model.stateful import Stateful
1214
from neuralmonkey.nn.utils import dropout
1315
from neuralmonkey.vocabulary import Vocabulary

0 commit comments

Comments
 (0)