18
18
We assume that if that metric is 'sort of decreasing', then everything is OK
19
19
and we are allowed to change tau coefficient further; otherwise we revert back
20
20
to the last "safe" value and stop
21
+
22
+ 'sort of decreasing' performs best with `PerplexityScore`, and all scores which
23
+ behave like perplexity (nonnegative, and which should decrease when a model gets better).
24
+ If you want to track a different kind of score, it is recommended to use `score_controller` parameter
21
25
22
26
More formal definition of "sort of decreasing": if we divide a curve into two parts like so:
23
27
46
50
| right part |
47
51
48
52
then the right part exceeds the global minimum by no more than 5%
49
- (you can change 5% if you like by adjusting `fraction_threshold`
50
- in `is_score_out_of_control` function)
53
+ (you can change 5% if you like by adjusting `fraction_threshold` parameter)
51
54
52
- If score_to_track is None, then `ControllerAgent` will never stop
55
+ If `score_to_track` is None and `score_controller` is None, then `ControllerAgent` will never stop
53
56
(useful for e.g. decaying coefficients)
54
-
57
+ fraction_threshold: float
58
+ Threshold to control a score by 'sort of decreasing' metric
59
+ score_controller: BaseScoreController
60
+ Custom score controller
61
+ If 'sort of decreasing' is not a suitable criterion for your score, you can create a custom score controller
62
+ inherited from `BaseScoreController`.
55
63
tau_converter: str or callable
56
64
Notably, def-style functions and lambda functions are allowed
57
65
If it is function, then it should accept four arguments:
109
117
that way agent will continue operating even outside this `RegularizationControllerCube`
110
118
""" # noqa: W291
111
119
112
- from .base_cube import BaseCube
113
- from ..rel_toolbox_lite import count_vocab_size , handle_regularizer
114
-
115
- import numexpr as ne
116
120
import warnings
117
- from dill .source import getsource
118
121
from copy import deepcopy
122
+ from dataclasses import dataclass
123
+ from typing import List , Optional
124
+
125
+ import numexpr as ne
119
126
import numpy as np
127
+ from dill .source import getsource
120
128
129
+ from .base_cube import BaseCube
130
+ from ..rel_toolbox_lite import count_vocab_size , handle_regularizer
121
131
122
132
W_HALT_CONTROL = "Process of dynamically changing tau was stopped at {} iteration"
123
133
W_MAX_ITERS = "Maximum number of iterations is exceeded; turning off"
124
134
125
135
126
- def is_score_out_of_control (model , score_name , fraction_threshold = 0.05 ):
127
- """
128
- Returns True if score isn't 'sort of decreasing' anymore.
136
@dataclass
class OutOfControlAnswer:
    """
    Result of a single "is the score out of control?" check.

    Attributes
    ----------
    answer : bool
        verdict of the check; `BaseScoreController.__call__` returns this value
    error_message : str or None
        optional explanation; when it is not None,
        `BaseScoreController.__call__` emits it through `warnings.warn`
    """
    answer: bool
    error_message: Optional[str] = None
129
140
130
- See docstring for RegularizationControllerCube for details
131
141
132
- Parameters
133
- ----------
134
- model : TopicModel
135
- score_name : str or None
136
- fraction_threshold : float
142
class BaseScoreController:
    """
    Base class for callables deciding whether a tracked score is out of control.

    A subclass implements `is_out_of_control`; calling the instance with a model
    looks the score up in `model.scores`, runs the check and returns its verdict.
    """

    def __init__(self, score_name):
        """
        Parameters
        ----------
        score_name : str or None
            key used to look the score values up in `model.scores`
        """
        self.score_name = score_name

    def get_score_values(self, model):
        """
        Fetch the values of the tracked score.

        Parameters
        ----------
        model
            object exposing a `scores` mapping (score name -> sequence of values)

        Returns
        -------
        sequence or None
            the stored values, or None when the score is absent from
            `model.scores` or has no values yet
        """
        if self.score_name not in model.scores:  # case of None is handled here as well
            return None

        vals = model.scores[self.score_name]
        # explicit length check: truthiness of array-like values may be ambiguous
        if len(vals) == 0:
            return None

        return vals

    def __call__(self, model):
        """
        Check the tracked score of `model`.

        Returns
        -------
        bool
            False when there are no values to check; otherwise the `answer`
            of the object returned by `is_out_of_control`

        Raises
        ------
        ValueError
            if `is_out_of_control` raises; the original exception is attached
            as `__cause__`
        """
        values = self.get_score_values(model)

        if values is None:
            return False

        try:
            out_of_control_result = self.is_out_of_control(values)
        except Exception as ex:
            message = (
                f"An error occured while controlling {self.score_name}."
                f" Message: {ex}. Score values: {values}"
            )
            # chain explicitly so the traceback of the real failure is kept
            raise ValueError(message) from ex

        if out_of_control_result.error_message is not None:
            warnings.warn(out_of_control_result.error_message)

        return out_of_control_result.answer

    def is_out_of_control(self, values: List[float]) -> "OutOfControlAnswer":
        """
        Decide whether `values` are out of control. Must be overridden.

        Parameters
        ----------
        values : list of float
            non-empty sequence of score values

        Returns
        -------
        OutOfControlAnswer
        """
        raise NotImplementedError
175
+
176
+
177
class PerplexityScoreController(BaseScoreController):
    """
    Controller is proper to control the Perplexity score. For others, please ensure for yourself.

    The score is reported as out of control when the maximum value reached at or
    after the global minimum exceeds that minimum by more than
    `fraction_threshold` (relative to the minimum).
    """
    DEFAULT_FRACTION_THRESHOLD = 0.05

    def __init__(self, score_name, fraction_threshold=DEFAULT_FRACTION_THRESHOLD):
        """
        Parameters
        ----------
        score_name : str or None
            name of the score to track
        fraction_threshold : float
            allowed relative growth of the score above its global minimum
        """
        super().__init__(score_name)
        self.fraction_threshold = fraction_threshold

    def is_out_of_control(self, values: List[float]):
        """
        Check whether the score has grown too much since its global minimum.

        Parameters
        ----------
        values : list of float
            non-empty sequence of score values

        Returns
        -------
        OutOfControlAnswer
            `answer` is True when the relative growth exceeds
            `fraction_threshold`; in that case `error_message` describes it

        Raises
        ------
        ValueError
            if the global minimum is not strictly positive
        """
        idxmin = int(np.argmin(values))
        minval = values[idxmin]

        # checked first so that a non-positive minimum is always reported
        if minval <= 0:
            err_message = f"""Score {self.score_name} has min_value = {minval} which is <= 0.
            This control scheme is using to control scores acting like Perplexity.
            Ensure you control the Perplexity score or write your own controller"""
            raise ValueError(err_message)

        # argmin is at most len(values) - 1, so the last index is the one to
        # compare against: the minimum is the latest value, nothing grew after it
        if idxmin == len(values) - 1:
            return OutOfControlAnswer(answer=False)

        right_maxval = max(values[idxmin:])
        answer = bool((right_maxval - minval) / minval > self.fraction_threshold)

        if answer:
            message = (
                f"Score {self.score_name} is too high!"
                f" Right max value: {right_maxval}, min value: {minval}"
            )
            return OutOfControlAnswer(answer=answer, error_message=message)

        return OutOfControlAnswer(answer=answer)
209
+
210
+
211
class ControllerAgentException(Exception):
    """Raised by `ControllerAgent` when its constructor arguments are inconsistent."""
164
212
165
213
166
214
class ControllerAgent :
@@ -172,8 +220,10 @@ class ControllerAgent:
172
220
Each agent is described by:
173
221
174
222
* reg_name: the name of regularizer having `tau` which needs to be changed
175
- * score_to_track: score providing control of the callback execution
176
223
* tau_converter: function or string describing how to get new `tau` from old `tau`
224
+ * score_to_track: score name providing control of the callback execution
225
+ * fraction_threshold: threshold to control score_to_track
226
+ * score_controller: custom score controller providing control of the callback execution
177
227
* local_dict: dictionary containing values of several variables,
178
228
most notably, `user_value`
179
229
* is_working:
@@ -183,31 +233,64 @@ class ControllerAgent:
183
233
184
234
See top-level docstring for details.
185
235
"""
186
- def __init__ (self , reg_name , score_to_track , tau_converter , max_iters , local_dict = None ):
236
+
237
+ def __init__ (self , reg_name , tau_converter , max_iters , score_to_track = None , fraction_threshold = None ,
238
+ score_controller = None , local_dict = None ):
187
239
"""
188
240
189
241
Parameters
190
242
----------
191
243
reg_name : str
192
- score_to_track : str, list of str or None
193
244
tau_converter : callable or str
194
- local_dict : dict
195
245
max_iters : int or float
196
246
Agent will stop changing tau after `max_iters` iterations
197
247
`max_iters` could be `float("NaN")` and `float("inf")` values:
198
248
that way agent will continue operating even outside this `RegularizationControllerCube`
249
+ score_to_track : str, list of str or None
250
+ Name of score to track.
251
+ Please, use this definition to track only scores of type PerplexityScore.
252
+ In other cases we recommend that you write your own ScoreController
253
+ fraction_threshold : float, list of float of the same length as score_to_track or None
254
+ Used to define the threshold for controlling PerplexityScore
255
+ Default value is 0.05
256
+ score_controller : BaseScoreController, list of BaseScoreController or None
257
+ local_dict : dict
199
258
"""
200
259
if local_dict is None :
201
260
local_dict = dict ()
202
261
203
262
self .reg_name = reg_name
204
263
self .tau_converter = tau_converter
264
+
265
+ self .score_controllers = []
205
266
if isinstance (score_to_track , list ):
206
- self .score_to_track = score_to_track
267
+ if fraction_threshold is None :
268
+ controller_params = [(name , PerplexityScoreController .DEFAULT_FRACTION_THRESHOLD ) for name in
269
+ score_to_track ]
270
+ elif isinstance (fraction_threshold , list ) and len (score_to_track ) == len (fraction_threshold ):
271
+ controller_params = list (zip (score_to_track , fraction_threshold ))
272
+ else :
273
+ err_message = """Length of score_to_track and fraction_threshold must be same.
274
+ Otherwise fraction_threshold must be None"""
275
+ raise ControllerAgentException (err_message )
276
+
277
+ self .score_controllers .append (
278
+ [PerplexityScoreController (name , threshold ) for (name , threshold ) in controller_params ])
279
+
207
280
elif isinstance (score_to_track , str ):
208
- self .score_to_track = [score_to_track ]
209
- else :
210
- self .score_to_track = []
281
+ self .score_controllers .append ([PerplexityScoreController (
282
+ score_to_track ,
283
+ fraction_threshold or PerplexityScoreController .DEFAULT_FRACTION_THRESHOLD
284
+ )])
285
+
286
+ if isinstance (score_controller , BaseScoreController ):
287
+ self .score_controllers .append (score_controller )
288
+ elif isinstance (score_controller , list ):
289
+ if not all (isinstance (score , BaseScoreController ) for score in score_controller ):
290
+ err_message = """score_controller must be of type BaseScoreController or list of BaseScoreController"""
291
+ raise ControllerAgentException (err_message )
292
+
293
+ self .score_controllers .extend (score_controller )
211
294
212
295
self .is_working = True
213
296
self .local_dict = local_dict
@@ -258,7 +341,7 @@ def invoke(self, model, cur_iter):
258
341
259
342
if self .is_working :
260
343
should_stop = any (
261
- is_score_out_of_control (model , score ) for score in self .score_to_track
344
+ score_controller (model ) for score_controller in self .score_controllers
262
345
)
263
346
if should_stop :
264
347
warnings .warn (W_HALT_CONTROL .format (len (self .tau_history )))
@@ -283,26 +366,31 @@ def __init__(self, num_iter: int, parameters,
283
366
regularizers params
284
367
each dict should contain the following fields:
285
368
("reg_name" or "regularizer"),
286
- "score_to_track" (optional),
287
369
"tau_converter",
370
+ "score_to_track" (optional),
371
+ "fraction_threshold" (optional),
372
+ "score_controller" (optional),
288
373
"user_value_grid"
289
374
See top-level docstring for details.
290
375
Examples:
291
376
292
377
>> {"regularizer": artm.regularizers.<...>,
293
- >> "score_to_track": "PerplexityScore@all",
294
378
>> "tau_converter": "prev_tau * user_value",
379
+ >> "score_to_track": "PerplexityScore@all",
380
+ >> "fraction_threshold": 0.1,
295
381
>> "user_value_grid": [0.5, 1, 2]}
296
382
297
383
298
384
-----------
299
385
300
386
>> {"reg_name": "decorrelator_for_ngramms",
301
- >> "score_to_track": None,
302
387
>> "tau_converter": (
303
388
>> lambda initial_tau, prev_tau, cur_iter, user_value:
304
389
>> initial_tau * (cur_iter % 2) + user_value
305
390
>> )
391
+ >> "score_to_track": None,
392
+ >> "fraction_threshold": None,
393
+ >> "score_controller": [PerplexityScoreController("PerplexityScore@all", 0.1)],
306
394
>> "user_value_grid": [0, 1]}
307
395
308
396
reg_search : str
0 commit comments