Skip to content

Commit 44731da

Browse files
custom score controller added
1 parent 7de00c4 commit 44731da

File tree

4 files changed

+197
-66
lines changed

4 files changed

+197
-66
lines changed

topicnet/cooking_machine/cubes/controller_cube.py

Lines changed: 123 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
We assume that if that metric is 'sort of decreasing', then everything is OK
1919
and we are allowed to change tau coefficient further; otherwise we revert back
2020
to the last "safe" value and stop
21+
22+
'sort of decreasing' perform best with PerplexityScore.
2123
2224
More formal definition of "sort of decreasing": if we divide a curve into two parts like so:
2325
@@ -46,11 +48,17 @@
4648
| right part |
4749
4850
then the right part is no higher than 5% of global minimum
49-
(you can change 5% if you like by adjusting `fraction_threshold`
50-
in `is_score_out_of_control` function)
51+
(you can change 5% if you like by adjusting `fraction_threshold` parameter)
5152
52-
If score_to_track is None, then `ControllerAgent` will never stop
53+
If score_to_track is None and score_controller is None, then `ControllerAgent` will never stop
5354
(useful for e.g. decaying coefficients)
55+
fraction_threshold: float
56+
Threshold to control a score by 'sort of decreasing' metric
57+
score_controller: ScoreControllerBase
58+
Custom score controller
59+
In case of 'sort of decreasing' is not proper to control score, you are able to create custom Score Controller
60+
inherited from ScoreControllerBase. New controller must contain function `is_score_out_of_control`
61+
which takes TopicModel as input and return answer of type OutOfControlAnswer
5462
5563
tau_converter: str or callable
5664
Notably, def-style functions and lambda functions are allowed
@@ -111,6 +119,8 @@
111119

112120
import warnings
113121
from copy import deepcopy
122+
from dataclasses import dataclass
123+
from typing import List, Optional
114124

115125
import numexpr as ne
116126
import numpy as np
@@ -123,57 +133,83 @@
123133
W_MAX_ITERS = "Maximum number of iterations is exceeded; turning off"
124134

125135

126-
def get_two_values_diff(min_val: float, max_val: float):
127-
if min_val == 0:
128-
return max_val
136+
@dataclass
137+
class OutOfControlAnswer:
138+
answer: Optional[bool]
139+
error_message: Optional[str] = None
129140

130-
answer = (max_val-min_val)/min_val
131141

132-
return answer
142+
class ScoreControllerBase:
143+
def __init__(self, score_name):
144+
self.score_name = score_name
133145

146+
def get_vals(self, model):
147+
if self.score_name not in model.scores: # case of None is handled here as well
148+
return None
134149

135-
def is_score_out_of_control(model, score_name, fraction_threshold=0.05):
136-
"""
137-
Returns True if score isn't 'sort of decreasing' anymore.
150+
vals = model.scores[self.score_name]
151+
if len(vals) == 0:
152+
return None
153+
154+
return vals
155+
156+
def __call__(self, model):
157+
values = self.get_vals(model)
158+
159+
if values is None:
160+
return False
138161

139-
See docstring for RegularizationControllerCube for details
162+
try:
163+
out_of_control_result = self.is_out_of_control(values)
164+
except Exception as ex:
165+
message = (f"An error occured while controlling {self.score_name}. Message: {ex}. Score values: {values}")
166+
warnings.warn(message)
167+
return True
140168

141-
Parameters
142-
----------
143-
model : TopicModel
144-
score_name : str or None
145-
fraction_threshold : float
169+
if out_of_control_result.error_message is not None:
170+
warnings.warn(out_of_control_result.error_message)
146171

147-
Returns
148-
-------
149-
bool
172+
return out_of_control_result.answer
150173

174+
def is_out_of_control(self, values: List[float]) -> OutOfControlAnswer:
175+
raise NotImplementedError
176+
177+
178+
class ScoreControllerPerplexity(ScoreControllerBase):
179+
"""
180+
Controller is properto control the Perplexity score. For others, please ensure for yourself.
151181
"""
182+
DEFAULT_FRACTION_THRESHOLD = 0.05
152183

153-
if score_name not in model.scores: # case of None is handled here as well
154-
return False
184+
def __init__(self, score_name, fraction_threshold=DEFAULT_FRACTION_THRESHOLD):
185+
super().__init__(score_name)
186+
self.fraction_threshold = fraction_threshold
155187

156-
vals = model.scores[score_name]
157-
if len(vals) == 0:
158-
return False
188+
def is_out_of_control(self, values: List[float]):
189+
idxmin = np.argmin(values)
159190

191+
if idxmin == len(values): # score is monotonically decreasing
192+
return False
160193

194+
right_maxval = max(values[idxmin:])
195+
minval = values[idxmin]
161196

162-
idxmin = np.argmin(vals)
197+
if minval <= 0:
198+
message = f"""Score {self.score_name} has min_value = {minval} which is <=0.
199+
This control scheme is using to control scores acting like Perplexity.
200+
Ensure you control the Perplexity score or write your own controller"""
201+
return OutOfControlAnswer(answer=True, error_message=message)
163202

164-
if idxmin == len(vals): # score is monotonically decreasing
165-
return False
166-
maxval = max(vals[idxmin:])
167-
minval = vals[idxmin]
168-
diff = get_two_values_diff(minval, maxval)
169-
answer = diff > fraction_threshold
203+
answer = (right_maxval - minval) / minval > self.fraction_threshold
170204

171-
if answer:
172-
msg = (f"Score {score_name} is too high: during training the value {maxval}"
173-
f" passed a treshold of {(1 + fraction_threshold) * minval}"
174-
f" (estimate is based on {idxmin} iteration)")
175-
warnings.warn(msg)
176-
return answer
205+
if answer:
206+
message = (f"Score {self.score_name} is too high! Right max value: {right_maxval}, min value: {minval}")
207+
return OutOfControlAnswer(answer=answer, error_message=message)
208+
209+
return OutOfControlAnswer(answer=answer)
210+
211+
212+
class ControllerAgentException(Exception): pass
177213

178214

179215
class ControllerAgent:
@@ -185,8 +221,10 @@ class ControllerAgent:
185221
Each agent is described by:
186222
187223
* reg_name: the name of regularizer having `tau` which needs to be changed
188-
* score_to_track: score providing control of the callback execution
189224
* tau_converter: function or string describing how to get new `tau` from old `tau`
225+
* score_to_track: score name providing control of the callback execution
226+
* fraction_threshold: threshold to control score_to_track
227+
* score_controller: custom score controller providing control of the callback execution
190228
* local_dict: dictionary containing values of several variables,
191229
most notably, `user_value`
192230
* is_working:
@@ -197,31 +235,61 @@ class ControllerAgent:
197235
See top-level docstring for details.
198236
"""
199237

200-
def __init__(self, reg_name, score_to_track, tau_converter, max_iters, local_dict=None):
238+
def __init__(self, reg_name, tau_converter, max_iters, score_to_track=None, fraction_threshold=None,
239+
score_controller=None, local_dict=None):
201240
"""
202241
203242
Parameters
204243
----------
205244
reg_name : str
206-
score_to_track : str, list of str or None
207245
tau_converter : callable or str
208-
local_dict : dict
209246
max_iters : int or float
210247
Agent will stop changing tau after `max_iters` iterations
211248
`max_iters` could be `float("NaN")` and `float("inf")` values:
212249
that way agent will continue operating even outside this `RegularizationControllerCube`
250+
score_to_track : str, list of str or None
251+
Name of score to track
252+
Please, use this definition to track only scores of type PerplexityScore.
253+
In other cases we recommend you to write you own ScoreController
254+
fraction_threshold : float, list of float of the same length as score_to_track or None
255+
Uses to define threshold to control PerplexityScore
256+
Default value is 0.05
257+
score_controller : ScoreControllerBase, list of ScoreControllerBase or None
258+
local_dict : dict
213259
"""
214260
if local_dict is None:
215261
local_dict = dict()
216262

217263
self.reg_name = reg_name
218264
self.tau_converter = tau_converter
265+
266+
self.score_controllers = []
219267
if isinstance(score_to_track, list):
220-
self.score_to_track = score_to_track
268+
if fraction_threshold is None:
269+
scores = [(ScoreControllerPerplexity.DEFAULT_FRACTION_THRESHOLD, name) for name in score_to_track]
270+
elif isinstance(fraction_threshold, list) and len(score_to_track) == len(fraction_threshold):
271+
scores = list(zip(score_to_track, fraction_threshold))
272+
else:
273+
err_message = """Length of score_to_track and fraction_threshold must be same.
274+
Otherwise fraction_threshold must be None"""
275+
raise ControllerAgentException(err_message)
276+
277+
self.score_controllers.append([ScoreControllerPerplexity(name, threshold) for (name, threshold) in scores])
278+
221279
elif isinstance(score_to_track, str):
222-
self.score_to_track = [score_to_track]
223-
else:
224-
self.score_to_track = []
280+
self.score_controllers.append([ScoreControllerPerplexity(
281+
score_to_track,
282+
fraction_threshold or ScoreControllerPerplexity.DEFAULT_FRACTION_THRESHOLD
283+
)])
284+
285+
if isinstance(score_controller, ScoreControllerBase):
286+
self.score_controllers.append(score_controller)
287+
elif isinstance(score_controller, list):
288+
if not all(isinstance(score, ScoreControllerBase) for score in score_controller):
289+
err_message = """score_controller must be of type ScoreControllerBase od list of ScoreControllerBase"""
290+
raise ControllerAgentException(err_message)
291+
292+
self.score_controllers.extend(score_controller)
225293

226294
self.is_working = True
227295
self.local_dict = local_dict
@@ -272,7 +340,7 @@ def invoke(self, model, cur_iter):
272340

273341
if self.is_working:
274342
should_stop = any(
275-
is_score_out_of_control(model, score) for score in self.score_to_track
343+
score_controller(model) for score_controller in self.score_controllers
276344
)
277345
if should_stop:
278346
warnings.warn(W_HALT_CONTROL.format(len(self.tau_history)))
@@ -297,26 +365,31 @@ def __init__(self, num_iter: int, parameters,
297365
regularizers params
298366
each dict should contain the following fields:
299367
("reg_name" or "regularizer"),
300-
"score_to_track" (optional),
301368
"tau_converter",
369+
"score_to_track" (optional),
370+
"fraction_threshold" (optional),
371+
"score_controller" (optional),
302372
"user_value_grid"
303373
See top-level docstring for details.
304374
Examples:
305375
306376
>> {"regularizer": artm.regularizers.<...>,
307-
>> "score_to_track": "PerplexityScore@all",
308377
>> "tau_converter": "prev_tau * user_value",
378+
>> "score_to_track": "PerplexityScore@all",
379+
>> "fraction_threshold": 0.1,
309380
>> "user_value_grid": [0.5, 1, 2]}
310381
311382
312383
-----------
313384
314385
>> {"reg_name": "decorrelator_for_ngramms",
315-
>> "score_to_track": None,
316386
>> "tau_converter": (
317387
>> lambda initial_tau, prev_tau, cur_iter, user_value:
318388
>> initial_tau * (cur_iter % 2) + user_value
319389
>> )
390+
>> "score_to_track": None,
391+
>> "fraction_threshold": None,
392+
>> "score_controller": [ScoreControllerPerplexity("PerplexityScore@all", 0.1)],
320393
>> "user_value_grid": [0, 1]}
321394
322395
reg_search : str

topicnet/tests/fixtures.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from topicnet.cooking_machine.cubes.controller_cube import ScoreControllerPerplexity
2+
3+
DATA_REG_CONTROLLER_SORT_OF_DECREASING = [
4+
([246.77072143554688,
5+
124.72193908691406,
6+
107.95775604248047,
7+
105.27597045898438,
8+
112.46900939941406,
9+
132.88259887695312], 0.1, True),
10+
([246.77072143554688,
11+
124.72193908691406,
12+
107.95775604248047,
13+
105.27597045898438,
14+
112.46900939941406], 0.1, False),
15+
([246.77072143554688,
16+
124.72193908691406,
17+
107.95775604248047,
18+
105.27597045898438,
19+
112.46900939941406], 0.05, True),
20+
21+
]
22+
DATA_AGENT_CONTROLLER_LEN_CHECK = [
23+
({
24+
"reg_name": "decorrelation",
25+
"score_to_track": "PerplexityScore@all",
26+
"tau_converter": "prev_tau * user_value",
27+
"max_iters": float("inf")
28+
}, 1),
29+
({
30+
"reg_name": "decorrelation",
31+
"score_to_track": ["PerplexityScore@all"],
32+
"tau_converter": "prev_tau + user_value",
33+
"max_iters": float("inf")
34+
}, 1),
35+
({
36+
"reg_name": "decorrelation",
37+
"score_to_track": None, # never stop working
38+
"tau_converter": "prev_tau * user_value",
39+
"max_iters": float("inf")
40+
}, 0),
41+
({
42+
"reg_name": "decorrelation",
43+
"score_to_track": None, # never stop working
44+
"score_controller": ScoreControllerPerplexity("PerplexityScore@all", 0.1),
45+
"tau_converter": "prev_tau * user_value",
46+
"max_iters": float("inf")
47+
}, 1),
48+
({
49+
"reg_name": "decorrelation",
50+
"score_to_track": "PerplexityScore@all", # never stop working
51+
"score_controller": ScoreControllerPerplexity("PerplexityScore@all", 0.1),
52+
"tau_converter": "prev_tau * user_value",
53+
"max_iters": float("inf")
54+
}, 2)
55+
]

topicnet/tests/test_cube_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
from topicnet.cooking_machine.cubes.controller_cube import ScoreControllerPerplexity, ControllerAgent
4+
from .fixtures import DATA_AGENT_CONTROLLER_LEN_CHECK, DATA_REG_CONTROLLER_SORT_OF_DECREASING
5+
6+
7+
@pytest.mark.parametrize('values, fraction, answer_true', DATA_REG_CONTROLLER_SORT_OF_DECREASING)
8+
def test_get_two_values_diff(values, fraction, answer_true):
9+
score_controller = ScoreControllerPerplexity('test', fraction)
10+
is_out_of_control = score_controller.is_out_of_control(values)
11+
12+
assert is_out_of_control.answer == answer_true
13+
14+
15+
@pytest.mark.parametrize('agent_blueprint, answer_true', DATA_AGENT_CONTROLLER_LEN_CHECK)
16+
def test_controllers_length(agent_blueprint, answer_true):
17+
agent = ControllerAgent(**agent_blueprint)
18+
19+
assert len(agent.score_controllers) == answer_true

topicnet/tests/test_data/test_cube_utils.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

0 commit comments

Comments
 (0)