Skip to content

Commit dd07051

Browse files
committed
Refactor: Replace print statements during training with logging
1 parent 63e6530 commit dd07051

File tree

7 files changed

+68
-15
lines changed

7 files changed

+68
-15
lines changed

ingredient_parser/en/preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ def _token_features(self, token: Token) -> dict[str, str | bool | int | float]:
10831083

10841084
return features
10851085

1086-
def sentence_features(self) -> list[dict[str, str | bool | int | float]]:
1086+
def sentence_features(self) -> list[dict[str, str | bool]]:
10871087
"""Return features for all tokens in sentence.
10881088
10891089
Returns

train.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import argparse
44
import json
5+
import logging
56
import os
7+
import sys
68
from random import randint
79

810
from train import (
@@ -24,6 +26,12 @@ def __call__(self, parser, namespace, values, option_strings):
2426
setattr(namespace, self.dest, json.loads(values))
2527

2628

29+
LOGGING_LEVEL = {
30+
0: logging.INFO,
31+
1: logging.DEBUG,
32+
}
33+
34+
2735
if __name__ == "__main__":
2836
parser = argparse.ArgumentParser(
2937
description="Train a CRF model to parse label token from recipe \
@@ -85,6 +93,13 @@ def __call__(self, parser, namespace, values, option_strings):
8593
action="store_true",
8694
help="Plot confusion matrix of token labels.",
8795
)
96+
train_parser.add_argument(
97+
"-v",
98+
help="Enable verbose output.",
99+
action="count",
100+
default=0,
101+
dest="verbose",
102+
)
88103

89104
multiple_parser_help = "Average CRF performance across multiple training cycles."
90105
multiple_parser = subparsers.add_parser("multiple", help=multiple_parser_help)
@@ -149,6 +164,13 @@ def __call__(self, parser, namespace, values, option_strings):
149164
type=int,
150165
help="Number of processes to spawn. Default to number of cpu cores.",
151166
)
167+
multiple_parser.add_argument(
168+
"-v",
169+
help="Enable verbose output.",
170+
action="count",
171+
default=0,
172+
dest="verbose",
173+
)
152174

153175
gridsearch_parser_help = (
154176
"Grid search over all combinations of model hyperparameters."
@@ -255,6 +277,13 @@ def __call__(self, parser, namespace, values, option_strings):
255277
action=ParseJsonArg,
256278
default=dict(),
257279
)
280+
gridsearch_parser.add_argument(
281+
"-v",
282+
help="Enable verbose output.",
283+
action="count",
284+
default=0,
285+
dest="verbose",
286+
)
258287

259288
featuresearch_parser_help = "Grid search over all sets of model features."
260289
featuresearch_parser = subparsers.add_parser(
@@ -311,6 +340,13 @@ def __call__(self, parser, namespace, values, option_strings):
311340
type=int,
312341
help="Seed value used for train/test split.",
313342
)
343+
featuresearch_parser.add_argument(
344+
"-v",
345+
help="Enable verbose output.",
346+
action="count",
347+
default=0,
348+
dest="verbose",
349+
)
314350

315351
utility_help = "Utilities to aid cleaning training data."
316352
utility_parser = subparsers.add_parser("utility", help=utility_help)
@@ -343,6 +379,12 @@ def __call__(self, parser, namespace, values, option_strings):
343379

344380
args = parser.parse_args()
345381

382+
logging.basicConfig(
383+
stream=sys.stdout,
384+
level=LOGGING_LEVEL[args.verbose],
385+
format="[%(levelname)s] (%(module)s) %(message)s",
386+
)
387+
346388
if args.command == "train":
347389
train_single(args)
348390
elif args.command == "multiple":

train/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"check_label_consistency",
88
"feature_search",
99
"grid_search",
10-
"train_embeddings",
1110
"train_multiple",
1211
"train_single",
1312
]

train/featuresearch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import argparse
44
import concurrent.futures as cf
5+
import logging
56
import os
67
import time
78
from datetime import timedelta
@@ -20,6 +21,8 @@
2021
load_datasets,
2122
)
2223

24+
logger = logging.getLogger(__name__)
25+
2326
DISCARDED_FEATURES = {
2427
0: [],
2528
1: [
@@ -195,8 +198,8 @@ def feature_search(args: argparse.Namespace):
195198
]
196199
argument_sets.append(arguments)
197200

198-
print(f"[INFO] Grid search over {len(argument_sets)} feature sets.")
199-
print(f"[INFO] {args.seed} is the random seed used for the train/test split.")
201+
logger.info(f"Grid search over {len(argument_sets)} feature sets.")
202+
logger.info(f"{args.seed} is the random seed used for the train/test split.")
200203

201204
eval_results = []
202205
with cf.ProcessPoolExecutor(max_workers=args.processes) as executor:

train/gridsearch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import argparse
44
import concurrent.futures as cf
5+
import logging
56
import os
67
import time
78
from datetime import timedelta
@@ -21,6 +22,8 @@
2122
load_datasets,
2223
)
2324

25+
logger = logging.getLogger(__name__)
26+
2427
# Valid parameter options for LBFGS training algorithm and expected types
2528
VALID_LBFGS_PARAMS = {
2629
"c1": (float, int),
@@ -483,8 +486,8 @@ def grid_search(args: argparse.Namespace):
483486

484487
arguments = generate_argument_sets(args)
485488

486-
print(f"[INFO] Grid search over {len(arguments)} hyperparameters combinations.")
487-
print(f"[INFO] {args.seed} is the random seed used for the train/test split.")
489+
logger.info(f"Grid search over {len(arguments)} hyperparameters combinations.")
490+
logger.info(f"{args.seed} is the random seed used for the train/test split.")
488491

489492
eval_results = []
490493
with cf.ProcessPoolExecutor(max_workers=args.processes) as executor:

train/train_model.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import argparse
44
import concurrent.futures as cf
55
import contextlib
6+
import logging
67
from pathlib import Path
78
from random import randint
89
from statistics import mean, stdev
@@ -23,6 +24,8 @@
2324
load_datasets,
2425
)
2526

27+
logger = logging.getLogger(__name__)
28+
2629

2730
def train_parser_model(
2831
vectors: DataVectors,
@@ -69,7 +72,7 @@ def train_parser_model(
6972
if seed is None:
7073
seed = randint(0, 1_000_000_000)
7174

72-
print(f"[INFO] {seed} is the random seed used for the train/test split.")
75+
logger.info(f"{seed} is the random seed used for the train/test split.")
7376

7477
# Split data into train and test sets
7578
# The stratify argument means that each dataset is represented proportionally
@@ -96,10 +99,10 @@ def train_parser_model(
9699
stratify=vectors.source,
97100
random_state=seed,
98101
)
99-
print(f"[INFO] {len(features_train):,} training vectors.")
100-
print(f"[INFO] {len(features_test):,} testing vectors.")
102+
logger.info(f"{len(features_train):,} training vectors.")
103+
logger.info(f"{len(features_test):,} testing vectors.")
101104

102-
print("[INFO] Training model with training data.")
105+
logger.info("Training model with training data.")
103106
trainer = pycrfsuite.Trainer(verbose=False) # type: ignore
104107
trainer.set_params(
105108
{
@@ -117,7 +120,7 @@ def train_parser_model(
117120
trainer.append(X, y)
118121
trainer.train(str(save_model))
119122

120-
print("[INFO] Evaluating model with test data.")
123+
logger.info("Evaluating model with test data.")
121124
tagger = pycrfsuite.Tagger() # type: ignore
122125
tagger.open(str(save_model))
123126

train/training_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import concurrent.futures as cf
44
import json
5+
import logging
56
import sqlite3
67
from dataclasses import dataclass
78
from functools import partial
@@ -17,6 +18,8 @@
1718

1819
from ingredient_parser import SUPPORTED_LANGUAGES
1920

21+
logger = logging.getLogger(__name__)
22+
2023
sqlite3.register_converter("json", json.loads)
2124

2225
DEFAULT_MODEL_LOCATION = "ingredient_parser/en/model.en.crfsuite"
@@ -185,7 +188,7 @@ def load_datasets(
185188
"""
186189
PreProcessor = select_preprocessor(table)
187190

188-
print("[INFO] Loading and transforming training data.")
191+
logger.info("Loading and transforming training data.")
189192

190193
n = len(datasets)
191194
with sqlite3.connect(database, detect_types=sqlite3.PARSE_DECLTYPES) as conn:
@@ -225,8 +228,8 @@ def load_datasets(
225228
discarded=sum(v.discarded for v in vectors),
226229
)
227230

228-
print(f"[INFO] {len(all_vectors.sentences):,} usable vectors.")
229-
print(f"[INFO] {all_vectors.discarded:,} discarded due to OTHER labels.")
231+
logger.info(f"{len(all_vectors.sentences):,} usable vectors.")
232+
logger.info(f"{all_vectors.discarded:,} discarded due to OTHER labels.")
230233
return all_vectors
231234

232235

@@ -378,5 +381,5 @@ def confusion_matrix(
378381
ax.tick_params(axis="x", labelrotation=45)
379382
fig.tight_layout()
380383
fig.savefig(figure_path)
381-
print(f"[INFO] Confusion matrix saved to {figure_path}")
384+
logger.info(f"Confusion matrix saved to {figure_path}.")
382385
plt.close(fig)

0 commit comments

Comments
 (0)