18
18
We assume that if that metric is 'sort of decreasing', then everything is OK
19
19
and we are allowed to change tau coefficient further; otherwise we revert back
20
20
to the last "safe" value and stop
21
+
22
+ The 'sort of decreasing' criterion performs best with PerplexityScore.
21
23
22
24
More formal definition of "sort of decreasing": if we divide a curve into two parts like so:
23
25
46
48
| right part |
47
49
48
50
then the right part exceeds the global minimum by no more than 5%
49
- (you can change 5% if you like by adjusting `fraction_threshold`
50
- in `is_score_out_of_control` function)
51
+ (you can change 5% if you like by adjusting `fraction_threshold` parameter)
51
52
52
- If score_to_track is None, then `ControllerAgent` will never stop
53
+ If score_to_track is None and score_controller is None, then `ControllerAgent` will never stop
53
54
(useful for e.g. decaying coefficients)
55
+ fraction_threshold: float
56
+ Threshold to control a score by 'sort of decreasing' metric
57
+ score_controller: ScoreControllerBase
58
+ Custom score controller
59
+ If 'sort of decreasing' is not a suitable way to control the score, you can create a custom score controller
60
+ inherited from ScoreControllerBase. The new controller must implement the method `is_out_of_control`,
61
+ which takes the list of score values as input and returns an answer of type OutOfControlAnswer
54
62
55
63
tau_converter: str or callable
56
64
Notably, def-style functions and lambda functions are allowed
111
119
112
120
import warnings
113
121
from copy import deepcopy
122
+ from dataclasses import dataclass
123
+ from typing import List , Optional
114
124
115
125
import numexpr as ne
116
126
import numpy as np
123
133
W_MAX_ITERS = "Maximum number of iterations is exceeded; turning off"
124
134
125
135
126
- def get_two_values_diff (min_val : float , max_val : float ):
127
- if min_val == 0 :
128
- return max_val
136
@dataclass
class OutOfControlAnswer:
    """
    Verdict produced by a score controller.

    `answer` tells whether the tracked score is considered out of control;
    `error_message`, when it is not None, is emitted as a warning
    by `ScoreControllerBase.__call__`.
    """
    # Verdict itself; the annotation also permits None
    answer: Optional[bool]
    # Explanation for the user; None means there is nothing to report
    error_message: Optional[str] = None
129
140
130
- answer = (max_val - min_val )/ min_val
131
141
132
- return answer
142
class ScoreControllerBase:
    """
    Base class for objects deciding whether a tracked score is out of control.

    Subclasses implement `is_out_of_control`. An instance is called with a model
    and returns a truthy value when the score is considered out of control.
    """

    def __init__(self, score_name):
        """
        Parameters
        ----------
        score_name : str or None
            Key of the tracked score in `model.scores`
        """
        self.score_name = score_name

    def get_vals(self, model):
        """
        Fetches values of the tracked score from `model.scores`.

        Returns
        -------
        Score values, or None if the score is absent from `model.scores`
        or has no values yet.
        """
        if self.score_name not in model.scores:  # case of None is handled here as well
            return None

        vals = model.scores[self.score_name]
        if len(vals) == 0:
            return None

        return vals

    def __call__(self, model):
        """
        Checks whether the tracked score of `model` is out of control.

        Returns False when there are no values to inspect and True when
        `is_out_of_control` raises (a warning describing the error is emitted).
        Otherwise the verdict of `is_out_of_control` is returned,
        and its `error_message`, if any, is emitted as a warning.
        """
        values = self.get_vals(model)

        if values is None:
            return False

        try:
            out_of_control_result = self.is_out_of_control(values)
        except Exception as ex:
            message = (
                f"An error occurred while controlling {self.score_name}."
                f" Message: {ex}. Score values: {values}"
            )
            warnings.warn(message)
            return True

        # A controller may return a bare bool instead of OutOfControlAnswer:
        # such a result has neither `error_message` nor `answer`, so use it as is
        error_message = getattr(out_of_control_result, "error_message", None)
        if error_message is not None:
            warnings.warn(error_message)

        return getattr(out_of_control_result, "answer", out_of_control_result)

    def is_out_of_control(self, values: List[float]) -> "OutOfControlAnswer":
        """
        Decides whether the given score values are out of control.

        Must be overridden in subclasses. Note that `__call__` catches
        the NotImplementedError raised here like any other exception.
        """
        raise NotImplementedError
176
+
177
+
178
class ScoreControllerPerplexity(ScoreControllerBase):
    """
    Controller suited to the Perplexity score. For other scores, please make sure
    yourself that this control scheme is appropriate.

    The score is considered out of control if, to the right of its global minimum,
    it exceeds that minimum by more than `fraction_threshold` (relative difference).
    """
    DEFAULT_FRACTION_THRESHOLD = 0.05

    def __init__(self, score_name, fraction_threshold=DEFAULT_FRACTION_THRESHOLD):
        """
        Parameters
        ----------
        score_name : str or None
            Key of the tracked score in `model.scores`
        fraction_threshold : float
            Allowed relative excess over the global minimum
        """
        super().__init__(score_name)
        self.fraction_threshold = fraction_threshold

    def is_out_of_control(self, values: List[float]) -> OutOfControlAnswer:
        """
        Parameters
        ----------
        values : list of float
            Non-empty sequence of score values

        Returns
        -------
        OutOfControlAnswer
            `answer` is True if the score is not 'sort of decreasing' anymore
            or if its minimum is not positive (the relative difference
            is meaningless in that case); `error_message` explains why
        """
        idxmin = int(np.argmin(values))

        # np.argmin never returns len(values): the last valid index is len(values) - 1.
        # If the minimum is the most recent value, there is no right part to compare with
        if idxmin == len(values) - 1:
            return OutOfControlAnswer(answer=False)

        right_maxval = max(values[idxmin:])
        minval = values[idxmin]

        if minval <= 0:
            message = (
                f"Score {self.score_name} has min_value = {minval} which is <=0.\n"
                "This control scheme is using to control scores acting like Perplexity.\n"
                "Ensure you control the Perplexity score or write your own controller"
            )
            return OutOfControlAnswer(answer=True, error_message=message)

        # bool() keeps the answer a plain bool even if values are numpy scalars
        answer = bool((right_maxval - minval) / minval > self.fraction_threshold)

        if answer:
            message = (
                f"Score {self.score_name} is too high!"
                f" Right max value: {right_maxval}, min value: {minval}"
            )
            return OutOfControlAnswer(answer=answer, error_message=message)

        return OutOfControlAnswer(answer=answer)
210
+
211
+
212
class ControllerAgentException(Exception):
    """Raised by `ControllerAgent` when its score-control arguments are inconsistent."""
177
213
178
214
179
215
class ControllerAgent :
@@ -185,8 +221,10 @@ class ControllerAgent:
185
221
Each agent is described by:
186
222
187
223
* reg_name: the name of regularizer having `tau` which needs to be changed
188
- * score_to_track: score providing control of the callback execution
189
224
* tau_converter: function or string describing how to get new `tau` from old `tau`
225
+ * score_to_track: score name providing control of the callback execution
226
+ * fraction_threshold: threshold to control score_to_track
227
+ * score_controller: custom score controller providing control of the callback execution
190
228
* local_dict: dictionary containing values of several variables,
191
229
most notably, `user_value`
192
230
* is_working:
@@ -197,31 +235,61 @@ class ControllerAgent:
197
235
See top-level docstring for details.
198
236
"""
199
237
200
- def __init__ (self , reg_name , score_to_track , tau_converter , max_iters , local_dict = None ):
238
+ def __init__ (self , reg_name , tau_converter , max_iters , score_to_track = None , fraction_threshold = None ,
239
+ score_controller = None , local_dict = None ):
201
240
"""
202
241
203
242
Parameters
204
243
----------
205
244
reg_name : str
206
- score_to_track : str, list of str or None
207
245
tau_converter : callable or str
208
- local_dict : dict
209
246
max_iters : int or float
210
247
Agent will stop changing tau after `max_iters` iterations
211
248
`max_iters` could be `float("NaN")` and `float("inf")` values:
212
249
that way agent will continue operating even outside this `RegularizationControllerCube`
250
+ score_to_track : str, list of str or None
251
+ Name of score to track
252
+ Please, use this definition to track only scores of type PerplexityScore.
253
+ In other cases we recommend that you write your own ScoreController
254
+ fraction_threshold : float, list of float of the same length as score_to_track or None
255
+ Used to define the threshold for controlling PerplexityScore
256
+ Default value is 0.05
257
+ score_controller : ScoreControllerBase, list of ScoreControllerBase or None
258
+ local_dict : dict
213
259
"""
214
260
if local_dict is None :
215
261
local_dict = dict ()
216
262
217
263
self .reg_name = reg_name
218
264
self .tau_converter = tau_converter
265
+
266
+ self .score_controllers = []
219
267
if isinstance (score_to_track , list ):
220
- self .score_to_track = score_to_track
268
+ if fraction_threshold is None :
269
+ scores = [(ScoreControllerPerplexity .DEFAULT_FRACTION_THRESHOLD , name ) for name in score_to_track ]
270
+ elif isinstance (fraction_threshold , list ) and len (score_to_track ) == len (fraction_threshold ):
271
+ scores = list (zip (score_to_track , fraction_threshold ))
272
+ else :
273
+ err_message = """Length of score_to_track and fraction_threshold must be same.
274
+ Otherwise fraction_threshold must be None"""
275
+ raise ControllerAgentException (err_message )
276
+
277
+ self .score_controllers .append ([ScoreControllerPerplexity (name , threshold ) for (name , threshold ) in scores ])
278
+
221
279
elif isinstance (score_to_track , str ):
222
- self .score_to_track = [score_to_track ]
223
- else :
224
- self .score_to_track = []
280
+ self .score_controllers .append ([ScoreControllerPerplexity (
281
+ score_to_track ,
282
+ fraction_threshold or ScoreControllerPerplexity .DEFAULT_FRACTION_THRESHOLD
283
+ )])
284
+
285
+ if isinstance (score_controller , ScoreControllerBase ):
286
+ self .score_controllers .append (score_controller )
287
+ elif isinstance (score_controller , list ):
288
+ if not all (isinstance (score , ScoreControllerBase ) for score in score_controller ):
289
+ err_message = """score_controller must be of type ScoreControllerBase od list of ScoreControllerBase"""
290
+ raise ControllerAgentException (err_message )
291
+
292
+ self .score_controllers .extend (score_controller )
225
293
226
294
self .is_working = True
227
295
self .local_dict = local_dict
@@ -272,7 +340,7 @@ def invoke(self, model, cur_iter):
272
340
273
341
if self .is_working :
274
342
should_stop = any (
275
- is_score_out_of_control (model , score ) for score in self .score_to_track
343
+ score_controller (model ) for score_controller in self .score_controllers
276
344
)
277
345
if should_stop :
278
346
warnings .warn (W_HALT_CONTROL .format (len (self .tau_history )))
@@ -297,26 +365,31 @@ def __init__(self, num_iter: int, parameters,
297
365
regularizers params
298
366
each dict should contain the following fields:
299
367
("reg_name" or "regularizer"),
300
- "score_to_track" (optional),
301
368
"tau_converter",
369
+ "score_to_track" (optional),
370
+ "fraction_threshold" (optional),
371
+ "score_controller" (optional),
302
372
"user_value_grid"
303
373
See top-level docstring for details.
304
374
Examples:
305
375
306
376
>> {"regularizer": artm.regularizers.<...>,
307
- >> "score_to_track": "PerplexityScore@all",
308
377
>> "tau_converter": "prev_tau * user_value",
378
+ >> "score_to_track": "PerplexityScore@all",
379
+ >> "fraction_threshold": 0.1,
309
380
>> "user_value_grid": [0.5, 1, 2]}
310
381
311
382
312
383
-----------
313
384
314
385
>> {"reg_name": "decorrelator_for_ngramms",
315
- >> "score_to_track": None,
316
386
>> "tau_converter": (
317
387
>> lambda initial_tau, prev_tau, cur_iter, user_value:
318
388
>> initial_tau * (cur_iter % 2) + user_value
319
389
>> ),
390
+ >> "score_to_track": None,
391
+ >> "fraction_threshold": None,
392
+ >> "score_controller": [ScoreControllerPerplexity("PerplexityScore@all", 0.1)],
320
393
>> "user_value_grid": [0, 1]}
321
394
322
395
reg_search : str
0 commit comments