Skip to content

Commit 2ae01dc

Browse files
committed
Fix model saving/loading
1 parent b5bccbb commit 2ae01dc

File tree

5 files changed

+11
-17
lines changed

5 files changed

+11
-17
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ patterns.json
99
*.npy
1010
*.json
1111
*.h5
12+
*.keras
1213
*.png
1314
*.parquet
1415
regex101/

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ This will generate `preprocessed_train.txt` which contains all the feature vecto
3232

3333
`pipenv run python train.py`
3434

35-
The model architecture will be stored in `nn_model_sherlock.json` with the weights in `nn_model_sherlock.weights.h5`.
35+
The model architecture will be stored in `nn_model_sherlock.json` with the weights in `nn_model_sherlock.weights.keras`.
3636

3737
## Evaluation
3838

explain.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ def update_sample(samples, N, sample):
4646
)
4747

4848
# Load the trained model
49-
model = tf.keras.models.model_from_json(open("nn_model_sherlock.json").read())
50-
model.load_weights("nn_model_sherlock.weights.h5")
49+
model = tf.keras.models.load_model(os.path.join(args.input_dir, "nn_model_sherlock.keras"))
5150

5251
# Produce a randomly sample of background from the training data
5352
background = []

test.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pyarrow.parquet import ParquetFile
88
from sklearn.metrics import classification_report
99
from sklearn.preprocessing import LabelEncoder
10-
from tensorflow.keras.models import model_from_json
10+
from tensorflow.keras.models import load_model
1111
from tqdm import tqdm
1212

1313
BATCH_SIZE = 1000
@@ -27,10 +27,7 @@
2727
# labels = le.transform(labels.values.ravel())
2828
num_examples = len(labels)
2929

30-
model = model_from_json(
31-
open(os.path.join(args.input_dir, "nn_model_sherlock.json"), "r").read()
32-
)
33-
model.load_weights(os.path.join(args.input_dir, "nn_model_sherlock.weights.h5"))
30+
model = load_model(os.path.join(args.input_dir, "nn_model_sherlock.keras"))
3431

3532
sys.stderr.write("Evaluating...\n")
3633
labels_pred = [""] * len(labels)

train.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,31 +51,31 @@
5151
regex_model1 = BatchNormalization(axis=1)(regex_model_input)
5252
regex_model2 = Dense(
5353
1000,
54-
activation=tf.nn.relu,
54+
activation='relu',
5555
kernel_regularizer=tf.keras.regularizers.l2(0.0001),
5656
)(regex_model1)
5757
regex_model3 = Dropout(0.35)(regex_model2)
5858
regex_model4 = Dense(
5959
1000,
60-
activation=tf.nn.relu,
60+
activation='relu',
6161
kernel_regularizer=tf.keras.regularizers.l2(0.0001),
6262
)(regex_model3)
6363

6464
merged_model2 = BatchNormalization(axis=1)(regex_model4)
6565
merged_model3 = Dense(
6666
500,
67-
activation=tf.nn.relu,
67+
activation='relu',
6868
kernel_regularizer=tf.keras.regularizers.l2(0.0001),
6969
)(merged_model2)
7070
merged_model4 = Dropout(0.35)(merged_model3)
7171
merged_model5 = Dense(
7272
500,
73-
activation=tf.nn.relu,
73+
activation='relu',
7474
kernel_regularizer=tf.keras.regularizers.l2(0.0001),
7575
)(merged_model4)
7676
merged_model_output = Dense(
7777
len(le.classes_),
78-
activation=tf.nn.softmax,
78+
activation='softmax',
7979
kernel_regularizer=tf.keras.regularizers.l2(0.0001),
8080
)(merged_model5)
8181

@@ -86,9 +86,6 @@
8686
loss="categorical_crossentropy",
8787
metrics=["categorical_accuracy"],
8888
)
89-
open(os.path.join(args.output_dir, "nn_model_sherlock.json"), "w").write(
90-
model.to_json()
91-
)
9289

9390
preprocessed = open(os.path.join(args.input_dir, "preprocessed_train.txt"), "r")
9491
i = 0
@@ -115,4 +112,4 @@
115112
pbar.update(len(matrix))
116113

117114
# Save the trained model weights
118-
model.save_weights(os.path.join(args.output_dir, "nn_model_sherlock.weights.h5"))
115+
model.save(os.path.join(args.output_dir, "nn_model_sherlock.keras"))

0 commit comments

Comments
 (0)