109
109
that way agent will continue operating even outside this `RegularizationControllerCube`
110
110
""" # noqa: W291
111
111
112
- from .base_cube import BaseCube
113
- from ..rel_toolbox_lite import count_vocab_size , handle_regularizer
114
-
115
- import numexpr as ne
116
112
import warnings
117
- from dill .source import getsource
118
113
from copy import deepcopy
114
+
115
+ import numexpr as ne
119
116
import numpy as np
117
+ from dill .source import getsource
120
118
119
+ from .base_cube import BaseCube
120
+ from ..rel_toolbox_lite import count_vocab_size , handle_regularizer
121
121
122
122
W_HALT_CONTROL = "Process of dynamically changing tau was stopped at {} iteration"
123
123
W_MAX_ITERS = "Maximum number of iterations is exceeded; turning off"
124
124
125
125
126
+ def get_two_values_diff (min_val : float , max_val : float ):
127
+ if min_val == 0 :
128
+ return max_val
129
+
130
+ answer = (max_val - min_val )/ min_val
131
+
132
+ return answer
133
+
134
+
126
135
def is_score_out_of_control (model , score_name , fraction_threshold = 0.05 ):
127
136
"""
128
137
Returns True if score isn't 'sort of decreasing' anymore.
@@ -148,13 +157,17 @@ def is_score_out_of_control(model, score_name, fraction_threshold=0.05):
148
157
if len (vals ) == 0 :
149
158
return False
150
159
160
+
161
+
151
162
idxmin = np .argmin (vals )
152
163
153
164
if idxmin == len (vals ): # score is monotonically decreasing
154
165
return False
155
166
maxval = max (vals [idxmin :])
156
167
minval = vals [idxmin ]
157
- answer = ((maxval - minval )/ abs (minval ) - 1.0 ) > fraction_threshold
168
+ diff = get_two_values_diff (minval , maxval )
169
+ answer = diff > fraction_threshold
170
+
158
171
if answer :
159
172
msg = (f"Score { score_name } is too high: during training the value { maxval } "
160
173
f" passed a treshold of { (1 + fraction_threshold ) * minval } "
@@ -183,6 +196,7 @@ class ControllerAgent:
183
196
184
197
See top-level docstring for details.
185
198
"""
199
+
186
200
def __init__ (self , reg_name , score_to_track , tau_converter , max_iters , local_dict = None ):
187
201
"""
188
202
0 commit comments