Skip to content

Commit 15e6e30

Browse files
Merge pull request #52 from oksanadanilova/feature/score_out_of_control_bugfix
right percentage distance + tests
2 parents e3cab00 + 3d03f63 commit 15e6e30

File tree

2 files changed

+210
-50
lines changed

2 files changed

+210
-50
lines changed

topicnet/cooking_machine/cubes/controller_cube.py

+138-50
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
We assume that if that metric is 'sort of decreasing', then everything is OK
1919
and we are allowed to change tau coefficient further; otherwise we revert back
2020
to the last "safe" value and stop
21+
22+
'sort of decreasing' performs best with `PerplexityScore`, and all scores which
23+
behave like perplexity (nonnegative, and which should decrease when a model gets better).
24+
If you want to track a different kind of score, it is recommended to use `score_controller` parameter
2125
2226
More formal definition of "sort of decreasing": if we divide a curve into two parts like so:
2327
@@ -46,12 +50,16 @@
4650
| right part |
4751
4852
then the right part exceeds the global minimum by no more than 5%
49-
(you can change 5% if you like by adjusting `fraction_threshold`
50-
in `is_score_out_of_control` function)
53+
(you can change 5% if you like by adjusting `fraction_threshold` parameter)
5154
52-
If score_to_track is None, then `ControllerAgent` will never stop
55+
If `score_to_track` is None and `score_controller` is None, then `ControllerAgent` will never stop
5356
(useful for e.g. decaying coefficients)
54-
57+
fraction_threshold: float
58+
Threshold to control a score by 'sort of decreasing' metric
59+
score_controller: BaseScoreController
60+
Custom score controller
61+
If the 'sort of decreasing' check is not suitable for your score, you can create a custom score controller
62+
inherited from `BaseScoreController`.
5563
tau_converter: str or callable
5664
Notably, def-style functions and lambda functions are allowed
5765
If it is function, then it should accept four arguments:
@@ -109,58 +117,98 @@
109117
that way agent will continue operating even outside this `RegularizationControllerCube`
110118
""" # noqa: W291
111119

112-
from .base_cube import BaseCube
113-
from ..rel_toolbox_lite import count_vocab_size, handle_regularizer
114-
115-
import numexpr as ne
116120
import warnings
117-
from dill.source import getsource
118121
from copy import deepcopy
122+
from dataclasses import dataclass
123+
from typing import List, Optional
124+
125+
import numexpr as ne
119126
import numpy as np
127+
from dill.source import getsource
120128

129+
from .base_cube import BaseCube
130+
from ..rel_toolbox_lite import count_vocab_size, handle_regularizer
121131

122132
W_HALT_CONTROL = "Process of dynamically changing tau was stopped at {} iteration"
123133
W_MAX_ITERS = "Maximum number of iterations is exceeded; turning off"
124134

125135

126-
def is_score_out_of_control(model, score_name, fraction_threshold=0.05):
127-
"""
128-
Returns True if score isn't 'sort of decreasing' anymore.
136+
@dataclass
class OutOfControlAnswer:
    """
    Verdict returned by `BaseScoreController.is_out_of_control`.

    `BaseScoreController.__call__` emits `error_message` as a warning
    when it is not None and then returns `answer`.
    """
    # True when the tracked score is considered out of control
    answer: bool
    # Optional human-readable explanation; None means "nothing to report"
    error_message: Optional[str] = None
129140

130-
See docstring for RegularizationControllerCube for details
131141

132-
Parameters
133-
----------
134-
model : TopicModel
135-
score_name : str or None
136-
fraction_threshold : float
142+
class BaseScoreController:
    """
    Base class for objects deciding whether a tracked score went out of control.

    Subclasses implement `is_out_of_control`; instances are then called
    with a model and return a plain boolean verdict.
    """

    def __init__(self, score_name):
        """
        Parameters
        ----------
        score_name : str or None
            Key looked up in `model.scores`
        """
        self.score_name = score_name

    def get_score_values(self, model):
        """
        Fetch the tracked score values from `model.scores`.

        Returns
        -------
        The stored values, or None when the score is not present in
        `model.scores` or has no values yet.
        """
        if self.score_name not in model.scores:  # case of None is handled here as well
            return None

        vals = model.scores[self.score_name]
        if len(vals) == 0:
            return None

        return vals

    def __call__(self, model):
        """
        Decide whether the tracked score of `model` is out of control.

        Returns
        -------
        bool
            False when there is nothing to inspect (unknown or empty score),
            otherwise the `answer` field of the `is_out_of_control` result.

        Raises
        ------
        ValueError
            If `is_out_of_control` raises; the original exception is
            attached as `__cause__`.
        """
        values = self.get_score_values(model)

        if values is None:
            return False

        try:
            out_of_control_result = self.is_out_of_control(values)
        except Exception as ex:
            message = (f"An error occured while controlling {self.score_name}. Message: {ex}. Score values: {values}")
            # chain explicitly so the root cause stays visible in tracebacks
            raise ValueError(message) from ex

        if out_of_control_result.error_message is not None:
            warnings.warn(out_of_control_result.error_message)

        return out_of_control_result.answer

    def is_out_of_control(self, values: List[float]) -> "OutOfControlAnswer":
        """
        Inspect score values and return an `OutOfControlAnswer`.

        Must be overridden by subclasses.
        """
        raise NotImplementedError
175+
176+
177+
class PerplexityScoreController(BaseScoreController):
    """
    Controller implementing the 'sort of decreasing' check.

    Intended for the Perplexity score and scores that behave like it
    (strictly positive; lower is better). For other scores, verify that
    the check makes sense or write your own controller.
    """
    DEFAULT_FRACTION_THRESHOLD = 0.05

    def __init__(self, score_name, fraction_threshold=DEFAULT_FRACTION_THRESHOLD):
        """
        Parameters
        ----------
        score_name : str
        fraction_threshold : float
            Allowed relative excess of the values at and after the global
            minimum over that minimum
        """
        super().__init__(score_name)
        self.fraction_threshold = fraction_threshold

    def is_out_of_control(self, values: List[float]) -> "OutOfControlAnswer":
        """
        Check whether the score stopped being 'sort of decreasing'.

        The largest value at or after the global minimum is compared with
        the minimum itself: the score is out of control when the relative
        difference exceeds `fraction_threshold`.

        Raises
        ------
        ValueError
            If the global minimum is not strictly positive.
        """
        if len(values) == 0:
            # nothing to judge yet; np.argmin would raise on an empty sequence
            return OutOfControlAnswer(answer=False)

        # np.argmin is always a valid index, so when the minimum is the last
        # value (monotonically decreasing score) the ratio below is simply 0
        idxmin = int(np.argmin(values))
        minval = values[idxmin]

        if minval <= 0:
            err_message = (
                f"Score {self.score_name} has min_value = {minval} which is <= 0.\n"
                "This control scheme is using to control scores acting like Perplexity.\n"
                "Ensure you control the Perplexity score or write your own controller"
            )
            raise ValueError(err_message)

        right_maxval = max(values[idxmin:])

        # bool() keeps the answer a plain Python bool even for numpy inputs
        answer = bool((right_maxval - minval) / minval > self.fraction_threshold)

        if answer:
            message = (f"Score {self.score_name} is too high! Right max value: {right_maxval}, min value: {minval}")
            return OutOfControlAnswer(answer=answer, error_message=message)

        return OutOfControlAnswer(answer=answer)
209+
210+
211+
class ControllerAgentException(Exception):
    """Raised by `ControllerAgent` when its score-control arguments are inconsistent."""
164212

165213

166214
class ControllerAgent:
@@ -172,8 +220,10 @@ class ControllerAgent:
172220
Each agent is described by:
173221
174222
* reg_name: the name of regularizer having `tau` which needs to be changed
175-
* score_to_track: score providing control of the callback execution
176223
* tau_converter: function or string describing how to get new `tau` from old `tau`
224+
* score_to_track: score name providing control of the callback execution
225+
* fraction_threshold: threshold to control score_to_track
226+
* score_controller: custom score controller providing control of the callback execution
177227
* local_dict: dictionary containing values of several variables,
178228
most notably, `user_value`
179229
* is_working:
@@ -183,31 +233,64 @@ class ControllerAgent:
183233
184234
See top-level docstring for details.
185235
"""
186-
def __init__(self, reg_name, score_to_track, tau_converter, max_iters, local_dict=None):
236+
237+
def __init__(self, reg_name, tau_converter, max_iters, score_to_track=None, fraction_threshold=None,
238+
score_controller=None, local_dict=None):
187239
"""
188240
189241
Parameters
190242
----------
191243
reg_name : str
192-
score_to_track : str, list of str or None
193244
tau_converter : callable or str
194-
local_dict : dict
195245
max_iters : int or float
196246
Agent will stop changing tau after `max_iters` iterations
197247
`max_iters` could be `float("NaN")` and `float("inf")` values:
198248
that way agent will continue operating even outside this `RegularizationControllerCube`
249+
score_to_track : str, list of str or None
250+
Name of score to track.
251+
Please, use this definition to track only scores of type PerplexityScore.
252+
In other cases we recommend you to write your own ScoreController
253+
fraction_threshold : float, list of float of the same length as score_to_track or None
254+
Used to define the threshold for controlling PerplexityScore
255+
Default value is 0.05
256+
score_controller : BaseScoreController, list of BaseScoreController or None
257+
local_dict : dict
199258
"""
200259
if local_dict is None:
201260
local_dict = dict()
202261

203262
self.reg_name = reg_name
204263
self.tau_converter = tau_converter
264+
265+
self.score_controllers = []
205266
if isinstance(score_to_track, list):
206-
self.score_to_track = score_to_track
267+
if fraction_threshold is None:
268+
controller_params = [(name, PerplexityScoreController.DEFAULT_FRACTION_THRESHOLD) for name in
269+
score_to_track]
270+
elif isinstance(fraction_threshold, list) and len(score_to_track) == len(fraction_threshold):
271+
controller_params = list(zip(score_to_track, fraction_threshold))
272+
else:
273+
err_message = """Length of score_to_track and fraction_threshold must be same.
274+
Otherwise fraction_threshold must be None"""
275+
raise ControllerAgentException(err_message)
276+
277+
self.score_controllers.append(
278+
[PerplexityScoreController(name, threshold) for (name, threshold) in controller_params])
279+
207280
elif isinstance(score_to_track, str):
208-
self.score_to_track = [score_to_track]
209-
else:
210-
self.score_to_track = []
281+
self.score_controllers.append([PerplexityScoreController(
282+
score_to_track,
283+
fraction_threshold or PerplexityScoreController.DEFAULT_FRACTION_THRESHOLD
284+
)])
285+
286+
if isinstance(score_controller, BaseScoreController):
287+
self.score_controllers.append(score_controller)
288+
elif isinstance(score_controller, list):
289+
if not all(isinstance(score, BaseScoreController) for score in score_controller):
290+
err_message = """score_controller must be of type BaseScoreController or list of BaseScoreController"""
291+
raise ControllerAgentException(err_message)
292+
293+
self.score_controllers.extend(score_controller)
211294

212295
self.is_working = True
213296
self.local_dict = local_dict
@@ -258,7 +341,7 @@ def invoke(self, model, cur_iter):
258341

259342
if self.is_working:
260343
should_stop = any(
261-
is_score_out_of_control(model, score) for score in self.score_to_track
344+
score_controller(model) for score_controller in self.score_controllers
262345
)
263346
if should_stop:
264347
warnings.warn(W_HALT_CONTROL.format(len(self.tau_history)))
@@ -283,26 +366,31 @@ def __init__(self, num_iter: int, parameters,
283366
regularizers params
284367
each dict should contain the following fields:
285368
("reg_name" or "regularizer"),
286-
"score_to_track" (optional),
287369
"tau_converter",
370+
"score_to_track" (optional),
371+
"fraction_threshold" (optional),
372+
"score_controller" (optional),
288373
"user_value_grid"
289374
See top-level docstring for details.
290375
Examples:
291376
292377
>> {"regularizer": artm.regularizers.<...>,
293-
>> "score_to_track": "PerplexityScore@all",
294378
>> "tau_converter": "prev_tau * user_value",
379+
>> "score_to_track": "PerplexityScore@all",
380+
>> "fraction_threshold": 0.1,
295381
>> "user_value_grid": [0.5, 1, 2]}
296382
297383
298384
-----------
299385
300386
>> {"reg_name": "decorrelator_for_ngramms",
301-
>> "score_to_track": None,
302387
>> "tau_converter": (
303388
>> lambda initial_tau, prev_tau, cur_iter, user_value:
304389
>> initial_tau * (cur_iter % 2) + user_value
305390
>> )
391+
>> "score_to_track": None,
392+
>> "fraction_threshold": None,
393+
>> "score_controller": [PerplexityScoreController("PerplexityScore@all", 0.1)],
306394
>> "user_value_grid": [0, 1]}
307395
308396
reg_search : str

topicnet/tests/test_cube_utils.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import pytest
2+
3+
from topicnet.cooking_machine.cubes.controller_cube import PerplexityScoreController, ControllerAgent
4+
5+
DATA_REG_CONTROLLER_SORT_OF_DECREASING = [
6+
([246.77072143554688,
7+
124.72193908691406,
8+
107.95775604248047,
9+
105.27597045898438,
10+
112.46900939941406,
11+
132.88259887695312], 0.1, True),
12+
([246.77072143554688,
13+
124.72193908691406,
14+
107.95775604248047,
15+
105.27597045898438,
16+
112.46900939941406], 0.1, False),
17+
([246.77072143554688,
18+
124.72193908691406,
19+
107.95775604248047,
20+
105.27597045898438,
21+
112.46900939941406], 0.05, True),
22+
23+
]
24+
DATA_AGENT_CONTROLLER_LEN_CHECK = [
25+
({
26+
"reg_name": "decorrelation",
27+
"score_to_track": "PerplexityScore@all",
28+
"tau_converter": "prev_tau * user_value",
29+
"max_iters": float("inf")
30+
}, 1),
31+
({
32+
"reg_name": "decorrelation",
33+
"score_to_track": ["PerplexityScore@all"],
34+
"tau_converter": "prev_tau + user_value",
35+
"max_iters": float("inf")
36+
}, 1),
37+
({
38+
"reg_name": "decorrelation",
39+
"score_to_track": None, # never stop working
40+
"tau_converter": "prev_tau * user_value",
41+
"max_iters": float("inf")
42+
}, 0),
43+
({
44+
"reg_name": "decorrelation",
45+
"score_to_track": None, # never stop working
46+
"score_controller": PerplexityScoreController("PerplexityScore@all", 0.1),
47+
"tau_converter": "prev_tau * user_value",
48+
"max_iters": float("inf")
49+
}, 1),
50+
({
51+
"reg_name": "decorrelation",
52+
"score_to_track": "PerplexityScore@all", # tracked in addition to the explicit score_controller
53+
"score_controller": PerplexityScoreController("PerplexityScore@all", 0.1),
54+
"tau_converter": "prev_tau * user_value",
55+
"max_iters": float("inf")
56+
}, 2)
57+
]
58+
59+
60+
@pytest.mark.parametrize('values, fraction, answer_true', DATA_REG_CONTROLLER_SORT_OF_DECREASING)
def test_perplexity_controller(values, fraction, answer_true):
    """Check the verdict of `PerplexityScoreController.is_out_of_control` for each parametrized case."""
    verdict = PerplexityScoreController('test', fraction).is_out_of_control(values)

    assert verdict.answer == answer_true
66+
67+
68+
@pytest.mark.parametrize('agent_blueprint, answer_true', DATA_AGENT_CONTROLLER_LEN_CHECK)
def test_controllers_length(agent_blueprint, answer_true):
    """Check how many score controllers `ControllerAgent` holds for each blueprint."""
    controllers = ControllerAgent(**agent_blueprint).score_controllers

    assert len(controllers) == answer_true

0 commit comments

Comments
 (0)