Skip to content

Commit 3d03f63

Browse files
refactoring
1 parent 44731da commit 3d03f63

File tree

3 files changed

+85
-86
lines changed

3 files changed

+85
-86
lines changed

topicnet/cooking_machine/cubes/controller_cube.py

+28-27
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
and we are allowed to change tau coefficient further; otherwise we revert back
2020
to the last "safe" value and stop
2121
22-
'sort of decreasing' perform best with PerplexityScore.
22+
'sort of decreasing' performs best with `PerplexityScore`, and all scores which
23+
behave like perplexity (nonnegative, and which should decrease when a model gets better).
24+
If you want to track a different kind of score, it is recommended to use `score_controller` parameter
2325
2426
More formal definition of "sort of decreasing": if we divide a curve into two parts like so:
2527
@@ -50,16 +52,14 @@
5052
then the right part is no higher than 5% of global minimum
5153
(you can change 5% if you like by adjusting `fraction_threshold` parameter)
5254
53-
If score_to_track is None and score_controller is None, then `ControllerAgent` will never stop
55+
If `score_to_track` is None and `score_controller` is None, then `ControllerAgent` will never stop
5456
(useful for e.g. decaying coefficients)
5557
fraction_threshold: float
5658
Threshold to control a score by 'sort of decreasing' metric
57-
score_controller: ScoreControllerBase
59+
score_controller: BaseScoreController
5860
Custom score controller
5961
In case of 'sort of decreasing' is not proper to control score, you are able to create custom Score Controller
60-
inherited from ScoreControllerBase. New controller must contain function `is_score_out_of_control`
61-
which takes TopicModel as input and return answer of type OutOfControlAnswer
62-
62+
inherited from `BaseScoreController`.
6363
tau_converter: str or callable
6464
Notably, def-style functions and lambda functions are allowed
6565
If it is function, then it should accept four arguments:
@@ -135,15 +135,15 @@
135135

136136
@dataclass
137137
class OutOfControlAnswer:
138-
answer: Optional[bool]
138+
answer: bool
139139
error_message: Optional[str] = None
140140

141141

142-
class ScoreControllerBase:
142+
class BaseScoreController:
143143
def __init__(self, score_name):
144144
self.score_name = score_name
145145

146-
def get_vals(self, model):
146+
def get_score_values(self, model):
147147
if self.score_name not in model.scores: # case of None is handled here as well
148148
return None
149149

@@ -154,7 +154,7 @@ def get_vals(self, model):
154154
return vals
155155

156156
def __call__(self, model):
157-
values = self.get_vals(model)
157+
values = self.get_score_values(model)
158158

159159
if values is None:
160160
return False
@@ -163,8 +163,7 @@ def __call__(self, model):
163163
out_of_control_result = self.is_out_of_control(values)
164164
except Exception as ex:
165165
message = (f"An error occured while controlling {self.score_name}. Message: {ex}. Score values: {values}")
166-
warnings.warn(message)
167-
return True
166+
raise ValueError(message)
168167

169168
if out_of_control_result.error_message is not None:
170169
warnings.warn(out_of_control_result.error_message)
@@ -175,9 +174,9 @@ def is_out_of_control(self, values: List[float]) -> OutOfControlAnswer:
175174
raise NotImplementedError
176175

177176

178-
class ScoreControllerPerplexity(ScoreControllerBase):
177+
class PerplexityScoreController(BaseScoreController):
179178
"""
180-
Controller is properto control the Perplexity score. For others, please ensure for yourself.
179+
Controller is proper to control the Perplexity score. For others, please ensure for yourself.
181180
"""
182181
DEFAULT_FRACTION_THRESHOLD = 0.05
183182

@@ -195,10 +194,10 @@ def is_out_of_control(self, values: List[float]):
195194
minval = values[idxmin]
196195

197196
if minval <= 0:
198-
message = f"""Score {self.score_name} has min_value = {minval} which is <=0.
197+
err_message = f"""Score {self.score_name} has min_value = {minval} which is <= 0.
199198
This control scheme is using to control scores acting like Perplexity.
200199
Ensure you control the Perplexity score or write your own controller"""
201-
return OutOfControlAnswer(answer=True, error_message=message)
200+
raise ValueError(err_message)
202201

203202
answer = (right_maxval - minval) / minval > self.fraction_threshold
204203

@@ -248,13 +247,13 @@ def __init__(self, reg_name, tau_converter, max_iters, score_to_track=None, frac
248247
`max_iters` could be `float("NaN")` and `float("inf")` values:
249248
that way agent will continue operating even outside this `RegularizationControllerCube`
250249
score_to_track : str, list of str or None
251-
Name of score to track
250+
Name of score to track.
252251
Please, use this definition to track only scores of type PerplexityScore.
253252
In other cases we recommend you to write you own ScoreController
254253
fraction_threshold : float, list of float of the same length as score_to_track or None
255254
Uses to define threshold to control PerplexityScore
256255
Default value is 0.05
257-
score_controller : ScoreControllerBase, list of ScoreControllerBase or None
256+
score_controller : BaseScoreController, list of BaseScoreController or None
258257
local_dict : dict
259258
"""
260259
if local_dict is None:
@@ -266,27 +265,29 @@ def __init__(self, reg_name, tau_converter, max_iters, score_to_track=None, frac
266265
self.score_controllers = []
267266
if isinstance(score_to_track, list):
268267
if fraction_threshold is None:
269-
scores = [(ScoreControllerPerplexity.DEFAULT_FRACTION_THRESHOLD, name) for name in score_to_track]
268+
controller_params = [(name, PerplexityScoreController.DEFAULT_FRACTION_THRESHOLD) for name in
269+
score_to_track]
270270
elif isinstance(fraction_threshold, list) and len(score_to_track) == len(fraction_threshold):
271-
scores = list(zip(score_to_track, fraction_threshold))
271+
controller_params = list(zip(score_to_track, fraction_threshold))
272272
else:
273273
err_message = """Length of score_to_track and fraction_threshold must be same.
274274
Otherwise fraction_threshold must be None"""
275275
raise ControllerAgentException(err_message)
276276

277-
self.score_controllers.append([ScoreControllerPerplexity(name, threshold) for (name, threshold) in scores])
277+
self.score_controllers.append(
278+
[PerplexityScoreController(name, threshold) for (name, threshold) in controller_params])
278279

279280
elif isinstance(score_to_track, str):
280-
self.score_controllers.append([ScoreControllerPerplexity(
281+
self.score_controllers.append([PerplexityScoreController(
281282
score_to_track,
282-
fraction_threshold or ScoreControllerPerplexity.DEFAULT_FRACTION_THRESHOLD
283+
fraction_threshold or PerplexityScoreController.DEFAULT_FRACTION_THRESHOLD
283284
)])
284285

285-
if isinstance(score_controller, ScoreControllerBase):
286+
if isinstance(score_controller, BaseScoreController):
286287
self.score_controllers.append(score_controller)
287288
elif isinstance(score_controller, list):
288-
if not all(isinstance(score, ScoreControllerBase) for score in score_controller):
289-
err_message = """score_controller must be of type ScoreControllerBase od list of ScoreControllerBase"""
289+
if not all(isinstance(score, BaseScoreController) for score in score_controller):
290+
err_message = """score_controller must be of type BaseScoreController or list of BaseScoreController"""
290291
raise ControllerAgentException(err_message)
291292

292293
self.score_controllers.extend(score_controller)
@@ -389,7 +390,7 @@ def __init__(self, num_iter: int, parameters,
389390
>> )
390391
>> "score_to_track": None,
391392
>> "fraction_threshold": None,
392-
>> "score_controller": [ScoreControllerPerplexity("PerplexityScore@all", 0.1)],
393+
>> "score_controller": [PerplexityScoreController("PerplexityScore@all", 0.1)],
393394
>> "user_value_grid": [0, 1]}
394395
395396
reg_search : str

topicnet/tests/fixtures.py

-55
This file was deleted.

topicnet/tests/test_cube_utils.py

+57-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,65 @@
11
import pytest
22

3-
from topicnet.cooking_machine.cubes.controller_cube import ScoreControllerPerplexity, ControllerAgent
4-
from .fixtures import DATA_AGENT_CONTROLLER_LEN_CHECK, DATA_REG_CONTROLLER_SORT_OF_DECREASING
3+
from topicnet.cooking_machine.cubes.controller_cube import PerplexityScoreController, ControllerAgent
4+
5+
DATA_REG_CONTROLLER_SORT_OF_DECREASING = [
6+
([246.77072143554688,
7+
124.72193908691406,
8+
107.95775604248047,
9+
105.27597045898438,
10+
112.46900939941406,
11+
132.88259887695312], 0.1, True),
12+
([246.77072143554688,
13+
124.72193908691406,
14+
107.95775604248047,
15+
105.27597045898438,
16+
112.46900939941406], 0.1, False),
17+
([246.77072143554688,
18+
124.72193908691406,
19+
107.95775604248047,
20+
105.27597045898438,
21+
112.46900939941406], 0.05, True),
22+
23+
]
24+
DATA_AGENT_CONTROLLER_LEN_CHECK = [
25+
({
26+
"reg_name": "decorrelation",
27+
"score_to_track": "PerplexityScore@all",
28+
"tau_converter": "prev_tau * user_value",
29+
"max_iters": float("inf")
30+
}, 1),
31+
({
32+
"reg_name": "decorrelation",
33+
"score_to_track": ["PerplexityScore@all"],
34+
"tau_converter": "prev_tau + user_value",
35+
"max_iters": float("inf")
36+
}, 1),
37+
({
38+
"reg_name": "decorrelation",
39+
"score_to_track": None, # never stop working
40+
"tau_converter": "prev_tau * user_value",
41+
"max_iters": float("inf")
42+
}, 0),
43+
({
44+
"reg_name": "decorrelation",
45+
"score_to_track": None, # never stop working
46+
"score_controller": PerplexityScoreController("PerplexityScore@all", 0.1),
47+
"tau_converter": "prev_tau * user_value",
48+
"max_iters": float("inf")
49+
}, 1),
50+
({
51+
"reg_name": "decorrelation",
52+
"score_to_track": "PerplexityScore@all", # never stop working
53+
"score_controller": PerplexityScoreController("PerplexityScore@all", 0.1),
54+
"tau_converter": "prev_tau * user_value",
55+
"max_iters": float("inf")
56+
}, 2)
57+
]
558

659

760
@pytest.mark.parametrize('values, fraction, answer_true', DATA_REG_CONTROLLER_SORT_OF_DECREASING)
8-
def test_get_two_values_diff(values, fraction, answer_true):
9-
score_controller = ScoreControllerPerplexity('test', fraction)
61+
def test_perplexity_controller(values, fraction, answer_true):
62+
score_controller = PerplexityScoreController('test', fraction)
1063
is_out_of_control = score_controller.is_out_of_control(values)
1164

1265
assert is_out_of_control.answer == answer_true

0 commit comments

Comments
 (0)