Skip to content

Commit 7de00c4

Browse files
right percentage distance + tests
1 parent aa4658e commit 7de00c4

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

topicnet/cooking_machine/cubes/controller_cube.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,20 +109,29 @@
109109
that way agent will continue operating even outside this `RegularizationControllerCube`
110110
""" # noqa: W291
111111

112-
from .base_cube import BaseCube
113-
from ..rel_toolbox_lite import count_vocab_size, handle_regularizer
114-
115-
import numexpr as ne
116112
import warnings
117-
from dill.source import getsource
118113
from copy import deepcopy
114+
115+
import numexpr as ne
119116
import numpy as np
117+
from dill.source import getsource
120118

119+
from .base_cube import BaseCube
120+
from ..rel_toolbox_lite import count_vocab_size, handle_regularizer
121121

122122
W_HALT_CONTROL = "Process of dynamically changing tau was stopped at {} iteration"
123123
W_MAX_ITERS = "Maximum number of iterations is exceeded; turning off"
124124

125125

126+
def get_two_values_diff(min_val: float, max_val: float):
127+
if min_val == 0:
128+
return max_val
129+
130+
answer = (max_val-min_val)/min_val
131+
132+
return answer
133+
134+
126135
def is_score_out_of_control(model, score_name, fraction_threshold=0.05):
127136
"""
128137
Returns True if score isn't 'sort of decreasing' anymore.
@@ -148,13 +157,17 @@ def is_score_out_of_control(model, score_name, fraction_threshold=0.05):
148157
if len(vals) == 0:
149158
return False
150159

160+
161+
151162
idxmin = np.argmin(vals)
152163

153164
if idxmin == len(vals): # score is monotonically decreasing
154165
return False
155166
maxval = max(vals[idxmin:])
156167
minval = vals[idxmin]
157-
answer = ((maxval - minval)/abs(minval) - 1.0) > fraction_threshold
168+
diff = get_two_values_diff(minval, maxval)
169+
answer = diff > fraction_threshold
170+
158171
if answer:
159172
msg = (f"Score {score_name} is too high: during training the value {maxval}"
160173
f" passed a treshold of {(1 + fraction_threshold) * minval}"
@@ -183,6 +196,7 @@ class ControllerAgent:
183196
184197
See top-level docstring for details.
185198
"""
199+
186200
def __init__(self, reg_name, score_to_track, tau_converter, max_iters, local_dict=None):
187201
"""
188202
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import pytest
2+
3+
from topicnet.cooking_machine.cubes.controller_cube import get_two_values_diff
4+
5+
data_2values = [
6+
(1, 2, 1),
7+
(0, 2, 2),
8+
(100, 115, 0.15)
9+
]
10+
11+
12+
@pytest.mark.parametrize('min_val, max_val, true_diff', data_2values)
13+
def test_get_two_values_diff(min_val, max_val, true_diff):
14+
diff = get_two_values_diff(min_val, max_val)
15+
16+
assert diff == true_diff

0 commit comments

Comments
 (0)