Skip to content

Commit b7f53b3

Browse files
committed
pep8
1 parent cf60956 commit b7f53b3

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

keras_wrapper/cnn_model.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def setOutputsMapping(self, outputsMapping, acc_output=None):
319319
the desired output order (in this case only one value can be provided).
320320
If it is Model then keys must be str.
321321
:param acc_output: name of the model's output that will be used for calculating
322-
the accuracy of the model (only needed for Graph models)
322+
the accuracy of the model (only needed for Model models)
323323
"""
324324
if isinstance(self.model, Sequential) and len(list(outputsMapping)) > 1:
325325
raise Exception("When using Sequential models only one output can be provided in outputsMapping")
@@ -404,9 +404,12 @@ def setOptimizer(self, lr=None, momentum=None, loss='categorical_crossentropy',
404404
if not self.silence:
405405
logger.info("Compiling model...")
406406

407-
# compile differently depending if our model is 'Sequential', 'Model' or 'Graph'
407+
# compile differently depending if our model is 'Sequential', 'Model'
408408
if isinstance(self.model, Sequential) or isinstance(self.model, Model):
409-
self.model.compile(optimizer=optimizer, metrics=metrics, loss=loss, loss_weights=loss_weights,
409+
self.model.compile(optimizer=optimizer,
410+
metrics=metrics,
411+
loss=loss,
412+
loss_weights=loss_weights,
410413
sample_weight_mode=sample_weight_mode)
411414
else:
412415
raise NotImplementedError()
@@ -576,8 +579,7 @@ def trainNetFromSamples(self, x, y, parameters=None, class_weight=None, sample_w
576579
:param parameters:
577580
:param class_weight:
578581
:param sample_weight:
579-
:param out_name: name of the output node that will be used to evaluate the network accuracy.
580-
Only applicable to Graph models.
582+
:param out_name: name of the output node that will be used to evaluate the network accuracy. Only applicable to Models.
581583
582584
The input 'parameters' is a dict() which may contain the following (optional) training parameters:
583585
#### Visualization parameters
@@ -884,7 +886,6 @@ def testNet(self, ds, parameters, out_name=None):
884886
for name, o in zip(self.model.metrics_names, out):
885887
logger.info('test ' + name + ': %0.8s' % o)
886888

887-
888889
def testNetSamples(self, X, batch_size=50):
889890
"""
890891
Applies a forward pass on the samples provided and returns the predicted classes and probabilities.
@@ -914,7 +915,7 @@ def testOnBatch(self, X, Y, accuracy=True, out_name=None):
914915
return loss, score, top_score, n_samples
915916
return loss, n_samples
916917
else:
917-
[data, last_output] = self._prepareGraphData(X, Y)
918+
[data, last_output] = self._prepareModelData(X, Y)
918919
loss = self.model.test_on_batch(data)
919920
loss = loss[0]
920921
if accuracy:
@@ -1594,10 +1595,10 @@ def predictOnBatch(self, X, in_name=None, out_name=None, expand=False):
15941595
predictions = self.model.predict_on_batch(X)
15951596

15961597
# Select output if indicated
1597-
if isinstance(self.model, Model): # Graph
1598+
if isinstance(self.model, Model):
15981599
if out_name:
15991600
predictions = predictions[out_name]
1600-
elif isinstance(self.model, Sequential): # Sequential
1601+
elif isinstance(self.model, Sequential):
16011602
predictions = predictions[0]
16021603

16031604
return predictions
@@ -1861,7 +1862,7 @@ def decode_predictions_one_hot(preds, index2word, verbose=0):
18611862

18621863
def prepareData(self, X_batch, Y_batch=None):
18631864
"""
1864-
Prepares the data for the model, depending on its type (Sequential, Model, Graph).
1865+
Prepares the data for the model, depending on its type (Sequential, Model).
18651866
:param X_batch: Batch of input data.
18661867
:param Y_batch: Batch output data.
18671868
:return: Prepared data.

keras_wrapper/models.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ def __init__(self,
7171
inheritance=inheritance,
7272
)
7373

74-
75-
7674
def basic_model(self, nOutput, model_input):
7775
"""
7876
Builds a basic CNN model.

0 commit comments

Comments
 (0)