19
19
and we are allowed to change tau coefficient further; otherwise we revert back
20
20
to the last "safe" value and stop
21
21
22
- 'sort of decreasing' perform best with PerplexityScore.
22
+ 'sort of decreasing' performs best with `PerplexityScore`, and all scores which
23
+ behave like perplexity (nonnegative, and which should decrease when a model gets better).
24
+ If you want to track a different kind of score, it is recommended to use `score_controller` parameter
23
25
24
26
More formal definition of "sort of decreasing": if we divide a curve into two parts like so:
25
27
50
52
then the right part is no higher than 5% of global minimum
51
53
(you can change 5% if you like by adjusting `fraction_threshold` parameter)
52
54
53
- If score_to_track is None and score_controller is None, then `ControllerAgent` will never stop
55
+ If ` score_to_track` is None and ` score_controller` is None, then `ControllerAgent` will never stop
54
56
(useful for e.g. decaying coefficients)
55
57
fraction_threshold: float
56
58
Threshold to control a score by 'sort of decreasing' metric
57
- score_controller: ScoreControllerBase
59
+ score_controller: BaseScoreController
58
60
Custom score controller
59
61
In case of 'sort of decreasing' is not proper to control score, you are able to create custom Score Controller
60
- inherited from ScoreControllerBase. New controller must contain function `is_score_out_of_control`
61
- which takes TopicModel as input and return answer of type OutOfControlAnswer
62
-
62
+ inherited from `BaseScoreController`.
63
63
tau_converter: str or callable
64
64
Notably, def-style functions and lambda functions are allowed
65
65
If it is function, then it should accept four arguments:
135
135
136
136
@dataclass
137
137
class OutOfControlAnswer :
138
- answer : Optional [ bool ]
138
+ answer : bool
139
139
error_message : Optional [str ] = None
140
140
141
141
142
- class ScoreControllerBase :
142
+ class BaseScoreController :
143
143
def __init__ (self , score_name ):
144
144
self .score_name = score_name
145
145
146
- def get_vals (self , model ):
146
+ def get_score_values (self , model ):
147
147
if self .score_name not in model .scores : # case of None is handled here as well
148
148
return None
149
149
@@ -154,7 +154,7 @@ def get_vals(self, model):
154
154
return vals
155
155
156
156
def __call__ (self , model ):
157
- values = self .get_vals (model )
157
+ values = self .get_score_values (model )
158
158
159
159
if values is None :
160
160
return False
@@ -163,8 +163,7 @@ def __call__(self, model):
163
163
out_of_control_result = self .is_out_of_control (values )
164
164
except Exception as ex :
165
165
message = (f"An error occured while controlling { self .score_name } . Message: { ex } . Score values: { values } " )
166
- warnings .warn (message )
167
- return True
166
+ raise ValueError (message )
168
167
169
168
if out_of_control_result .error_message is not None :
170
169
warnings .warn (out_of_control_result .error_message )
@@ -175,9 +174,9 @@ def is_out_of_control(self, values: List[float]) -> OutOfControlAnswer:
175
174
raise NotImplementedError
176
175
177
176
178
- class ScoreControllerPerplexity ( ScoreControllerBase ):
177
+ class PerplexityScoreController ( BaseScoreController ):
179
178
"""
180
- Controller is properto control the Perplexity score. For others, please ensure for yourself.
179
+ Controller is proper to control the Perplexity score. For others, please ensure for yourself.
181
180
"""
182
181
DEFAULT_FRACTION_THRESHOLD = 0.05
183
182
@@ -195,10 +194,10 @@ def is_out_of_control(self, values: List[float]):
195
194
minval = values [idxmin ]
196
195
197
196
if minval <= 0 :
198
- message = f"""Score { self .score_name } has min_value = { minval } which is <=0.
197
+ err_message = f"""Score { self .score_name } has min_value = { minval } which is <= 0.
199
198
This control scheme is using to control scores acting like Perplexity.
200
199
Ensure you control the Perplexity score or write your own controller"""
201
- return OutOfControlAnswer ( answer = True , error_message = message )
200
+ raise ValueError ( err_message )
202
201
203
202
answer = (right_maxval - minval ) / minval > self .fraction_threshold
204
203
@@ -248,13 +247,13 @@ def __init__(self, reg_name, tau_converter, max_iters, score_to_track=None, frac
248
247
`max_iters` could be `float("NaN")` and `float("inf")` values:
249
248
that way agent will continue operating even outside this `RegularizationControllerCube`
250
249
score_to_track : str, list of str or None
251
- Name of score to track
250
+ Name of score to track.
252
251
Please, use this definition to track only scores of type PerplexityScore.
253
252
In other cases we recommend you to write you own ScoreController
254
253
fraction_threshold : float, list of float of the same length as score_to_track or None
255
254
Uses to define threshold to control PerplexityScore
256
255
Default value is 0.05
257
- score_controller : ScoreControllerBase , list of ScoreControllerBase or None
256
+ score_controller : BaseScoreController , list of BaseScoreController or None
258
257
local_dict : dict
259
258
"""
260
259
if local_dict is None :
@@ -266,27 +265,29 @@ def __init__(self, reg_name, tau_converter, max_iters, score_to_track=None, frac
266
265
self .score_controllers = []
267
266
if isinstance (score_to_track , list ):
268
267
if fraction_threshold is None :
269
- scores = [(ScoreControllerPerplexity .DEFAULT_FRACTION_THRESHOLD , name ) for name in score_to_track ]
268
+ controller_params = [(name , PerplexityScoreController .DEFAULT_FRACTION_THRESHOLD ) for name in
269
+ score_to_track ]
270
270
elif isinstance (fraction_threshold , list ) and len (score_to_track ) == len (fraction_threshold ):
271
- scores = list (zip (score_to_track , fraction_threshold ))
271
+ controller_params = list (zip (score_to_track , fraction_threshold ))
272
272
else :
273
273
err_message = """Length of score_to_track and fraction_threshold must be same.
274
274
Otherwise fraction_threshold must be None"""
275
275
raise ControllerAgentException (err_message )
276
276
277
- self .score_controllers .append ([ScoreControllerPerplexity (name , threshold ) for (name , threshold ) in scores ])
277
+ self .score_controllers .append (
278
+ [PerplexityScoreController (name , threshold ) for (name , threshold ) in controller_params ])
278
279
279
280
elif isinstance (score_to_track , str ):
280
- self .score_controllers .append ([ScoreControllerPerplexity (
281
+ self .score_controllers .append ([PerplexityScoreController (
281
282
score_to_track ,
282
- fraction_threshold or ScoreControllerPerplexity .DEFAULT_FRACTION_THRESHOLD
283
+ fraction_threshold or PerplexityScoreController .DEFAULT_FRACTION_THRESHOLD
283
284
)])
284
285
285
- if isinstance (score_controller , ScoreControllerBase ):
286
+ if isinstance (score_controller , BaseScoreController ):
286
287
self .score_controllers .append (score_controller )
287
288
elif isinstance (score_controller , list ):
288
- if not all (isinstance (score , ScoreControllerBase ) for score in score_controller ):
289
- err_message = """score_controller must be of type ScoreControllerBase od list of ScoreControllerBase """
289
+ if not all (isinstance (score , BaseScoreController ) for score in score_controller ):
290
+ err_message = """score_controller must be of type BaseScoreController or list of BaseScoreController """
290
291
raise ControllerAgentException (err_message )
291
292
292
293
self .score_controllers .extend (score_controller )
@@ -389,7 +390,7 @@ def __init__(self, num_iter: int, parameters,
389
390
>> )
390
391
>> "score_to_track": None,
391
392
>> "fraction_threshold": None,
392
- >> "score_controller": [ScoreControllerPerplexity ("PerplexityScore@all", 0.1)],
393
+ >> "score_controller": [PerplexityScoreController ("PerplexityScore@all", 0.1)],
393
394
>> "user_value_grid": [0, 1]}
394
395
395
396
reg_search : str
0 commit comments