Skip to content

Commit 5872ade

Browse files
committed
Enable parallelization during text stemming
1 parent 8a9f757 commit 5872ade

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

quantulum3/classifier.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import json
88
import logging
99
import pkg_resources
10+
import os
11+
import multiprocessing
1012

1113
# Semi-dependencies
1214
try:
@@ -106,25 +108,36 @@ def clean_text(text, lang='en_US'):
106108
return _get_classifier(lang).clean_text(text)
107109

108110

111+
def _clean_text_lang(lang):
    """Return the ``clean_text`` callable of the classifier for *lang*.

    Defined at module level so the callable it yields can be handed to
    ``multiprocessing.Pool.map`` as the worker function (see the caller
    in ``train_classifier``).
    """
    classifier = _get_classifier(lang)
    return classifier.clean_text
113+
114+
109115
###############################################################################
110116
def train_classifier(parameters=None,
111117
ngram_range=(1, 1),
112118
store=True,
113-
lang='en_US'):
119+
lang='en_US',
120+
n_jobs=len(os.sched_getaffinity(0))):
114121
"""
115122
Train the intent classifier
116123
TODO auto invoke if sklearn version is new or first install or sth
117124
@:param store (bool) store classifier in clf.joblib
118125
"""
126+
_LOGGER.info("Started training, parallelized with {} jobs".format(n_jobs))
119127
_LOGGER.info("Loading training set")
120128
training_set = load.training_set(lang)
121129
target_names = list(frozenset([i['unit'] for i in training_set]))
122130

123131
_LOGGER.info("Preparing training set")
124-
train_data, train_target = [], []
125-
for example in training_set:
126-
train_data.append(clean_text(example['text'], lang))
127-
train_target.append(target_names.index(example['unit']))
132+
133+
if n_jobs > 1:
134+
with multiprocessing.Pool(processes=n_jobs) as p:
135+
train_data = p.map(_clean_text_lang(lang), [ex['text'] for ex in training_set])
136+
else:
137+
# This allows for classifier training in the interactive python shell
138+
train_data = [_clean_text_lang(lang)(ex['text']) for ex in training_set]
139+
140+
train_target = [target_names.index(example['unit']) for example in training_set]
128141

129142
tfidf_model = TfidfVectorizer(
130143
sublinear_tf=True,
@@ -139,7 +152,7 @@ def train_classifier(parameters=None,
139152
'loss': 'log',
140153
'penalty': 'l2',
141154
'tol': 1e-3,
142-
'n_jobs': -1,
155+
'n_jobs': n_jobs,
143156
'alpha': 0.0001,
144157
'fit_intercept': True,
145158
'random_state': 0,

0 commit comments

Comments
 (0)