6
6
from neuralmonkey .vocabulary import Vocabulary
7
7
from neuralmonkey .model .model_part import ModelPart , FeedDict
8
8
from neuralmonkey .nn .mlp import MultilayerPerceptron
9
-
10
-
11
- # pylint: disable=too-many-instance-attributes
9
+ from neuralmonkey .decorators import tensor
12
10
13
11
14
12
class SequenceClassifier (ModelPart ):
@@ -53,36 +51,62 @@ def __init__(self,
53
51
self .dropout_keep_prob = dropout_keep_prob
54
52
self .max_output_len = 1
55
53
56
- with self .use_scope ():
57
- self .train_mode = tf .placeholder (tf .bool , name = "train_mode" )
58
- self .learning_step = tf .get_variable (
59
- "learning_step" , [], trainable = False ,
60
- initializer = tf .constant_initializer (0 ))
61
-
62
- self .gt_inputs = [tf .placeholder (
63
- tf .int32 , shape = [None ], name = "targets" )]
64
- mlp_input = tf .concat ([enc .encoded for enc in encoders ], 1 )
65
- mlp = MultilayerPerceptron (
66
- mlp_input , layers , self .dropout_keep_prob , len (vocabulary ),
67
- activation_fn = self .activation_fn , train_mode = self .train_mode )
68
-
69
- self .loss_with_gt_ins = tf .reduce_mean (
70
- tf .nn .sparse_softmax_cross_entropy_with_logits (
71
- logits = mlp .logits , labels = self .gt_inputs [0 ]))
72
- self .loss_with_decoded_ins = self .loss_with_gt_ins
73
- self .cost = self .loss_with_gt_ins
74
-
75
- self .decoded_seq = [mlp .classification ]
76
- self .decoded_logits = [mlp .logits ]
77
- self .runtime_logprobs = [tf .nn .log_softmax (mlp .logits )]
78
-
79
- tf .summary .scalar (
80
- 'val_optimization_cost' , self .cost ,
81
- collections = ["summary_val" ])
82
- tf .summary .scalar (
83
- 'train_optimization_cost' ,
84
- self .cost , collections = ["summary_train" ])
85
- # pylint: enable=too-many-arguments
54
+ tf .summary .scalar (
55
+ 'train_optimization_cost' ,
56
+ self .cost , collections = ["summary_train" ])
57
+ # pylint: enable=too-many-arguments
58
+
59
+ # pylint: disable=no-self-use
60
+ @tensor
61
+ def train_mode (self ) -> tf .Tensor :
62
+ return tf .placeholder (tf .bool , name = "train_mode" )
63
+
64
+ @tensor
65
+ def gt_inputs (self ) -> List [tf .Tensor ]:
66
+ return [tf .placeholder (tf .int32 , shape = [None ], name = "targets" )]
67
+ # pylint: enable=no-self-use
68
+
69
+ @tensor
70
+ def _mlp (self ) -> MultilayerPerceptron :
71
+ mlp_input = tf .concat ([enc .encoded for enc in self .encoders ], 1 )
72
+ return MultilayerPerceptron (
73
+ mlp_input , self .layers ,
74
+ self .dropout_keep_prob , len (self .vocabulary ),
75
+ activation_fn = self .activation_fn , train_mode = self .train_mode )
76
+
77
+ @tensor
78
+ def loss_with_gt_ins (self ) -> tf .Tensor :
79
+ # pylint: disable=no-member,unsubscriptable-object
80
+ return tf .reduce_mean (
81
+ tf .nn .sparse_softmax_cross_entropy_with_logits (
82
+ logits = self ._mlp .logits , labels = self .gt_inputs [0 ]))
83
+ # pylint: enable=no-member,unsubscriptable-object
84
+
85
+ @property
86
+ def loss_with_decoded_ins (self ) -> tf .Tensor :
87
+ return self .loss_with_gt_ins
88
+
89
+ @property
90
+ def cost (self ) -> tf .Tensor :
91
+ return self .loss_with_gt_ins
92
+
93
+ @tensor
94
+ def decoded_seq (self ) -> List [tf .Tensor ]:
95
+ # pylint: disable=no-member
96
+ return [self ._mlp .classification ]
97
+ # pylint: enable=no-member
98
+
99
+ @tensor
100
+ def decoded_logits (self ) -> List [tf .Tensor ]:
101
+ # pylint: disable=no-member
102
+ return [self ._mlp .logits ]
103
+ # pylint: enable=no-member
104
+
105
+ @tensor
106
+ def runtime_logprobs (self ) -> List [tf .Tensor ]:
107
+ # pylint: disable=no-member
108
+ return [tf .nn .log_softmax (self ._mlp .logits )]
109
+ # pylint: enable=no-member
86
110
87
111
@property
88
112
def train_loss (self ):
@@ -108,7 +132,9 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
108
132
label_tensors , _ = self .vocabulary .sentences_to_tensor (
109
133
sentences_list , self .max_output_len )
110
134
135
+ # pylint: disable=unsubscriptable-object
111
136
fd [self .gt_inputs [0 ]] = label_tensors [0 ]
137
+ # pylint: enable=unsubscriptable-object
112
138
113
139
fd [self .train_mode ] = train
114
140
0 commit comments