Skip to content

Commit cae888b

Browse files
Merge pull request #17 from antoinedemathelin/master
feat: Override predict of network to gain speed
2 parents 041359b + cd735f9 commit cae888b

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

adapt/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@
1818
from tensorflow.keras.models import clone_model
1919

2020

21+
def predict(self, x, **kwargs):
22+
if (np.prod(x.shape) <= 10**8):
23+
pred = self.__call__(tf.identity(x)).numpy()
24+
else:
25+
pred = Sequential.predict(self, x, **kwargs)
26+
return pred
27+
28+
2129
def check_one_array(X, multisource=False):
2230
"""
2331
Check array and reshape 1D array in 2D array
@@ -288,6 +296,9 @@ def check_network(network, copy=True,
288296
(display_name))
289297
else:
290298
new_network = network
299+
300+
# Override the predict method to speed the prediction for small dataset
301+
new_network.predict = predict.__get__(new_network)
291302
return new_network
292303

293304

tests/test_cdan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_fit_lambda_zero():
6161
epochs=300, verbose=0)
6262
assert isinstance(model.model_, Model)
6363
assert model.history_['task_acc_s'][-1] > 0.9
64-
assert model.history_['task_acc_t'][-1] < 0.6
64+
assert model.history_['task_acc_t'][-1] < 0.7
6565

6666

6767
def test_fit_lambda_one_no_entropy():

tests/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,13 @@ def test_check_network_compile():
265265
"Please use `model.compile(optimizer, loss)`."
266266
in str(excinfo.value))
267267

268+
269+
def test_check_network_high_dataset():
270+
Xs, ys, Xt, yt = make_regression_da(100000, 1001)
271+
net = _get_model_Sequential(compiled=True)
272+
new_net = check_network(net, copy=True, compile_=True)
273+
new_net.predict(Xs)
274+
268275

269276
estimators = [
270277
Ridge(),

0 commit comments

Comments
 (0)