Skip to content

Commit 3b5aaa7

Browse files
Merge pull request #31 from antoinedemathelin/master
fix: Update transfer tree + Fix clone for parameter based methods
2 parents ea54493 + 5389aff commit 3b5aaa7

23 files changed

+606
-274
lines changed

adapt/parameter_based/decision_trees/tree_utils.py renamed to adapt/_tree_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,8 @@ def coherent_new_split(phi,th,rule):
464 464
return 0,1
465 465
else:
466 466
return 1,0
467     -
    467 +
    468 +
468 469
def all_coherent_splits(rule,all_splits):
469 470

470 471
inds = np.zeros(all_splits.shape[0],dtype=bool)

adapt/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def normalized_frechet_distance(Xs, Xt):
413 413
return x_max / Xs.shape[1]
414 414

415 415

416     -
def j_score(Xs, Xt, max_centers=100, sigma=None):
    416 +
def neg_j_score(Xs, Xt, max_centers=100, sigma=None):
417 417
"""
418 418
Compute the negative J-score between Xs and Xt.
419 419

adapt/parameter_based/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,10 @@
4 4

5 5
from ._regular import RegularTransferLR, RegularTransferLC, RegularTransferNN
6 6
from ._finetuning import FineTuning
    7 +
from ._transfer_tree import TransferTreeClassifier
7 8

8     -
__all__ = ["RegularTransferLR", "RegularTransferLC", "RegularTransferNN", "FineTuning"]
    9 +
__all__ = ["RegularTransferLR",
    10 +
"RegularTransferLC",
    11 +
"RegularTransferNN",
    12 +
"FineTuning",
    13 +
"TransferTreeClassifier"]

adapt/parameter_based/_finetuning.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1 1
import tensorflow as tf
2 2
from adapt.base import BaseAdaptDeep, make_insert_doc
    3 +
from adapt.utils import check_fitted_network
3 4

4 5

5 6
@make_insert_doc(["encoder", "task"], supervised=True)
@@ -67,6 +68,9 @@ def __init__(self,
67 68
random_state=None,
68 69
**params):
69 70

    71 +
encoder = check_fitted_network(encoder)
    72 +
task = check_fitted_network(task)
    73 +
70 74
names = self._get_param_names()
71 75
kwargs = {k: v for k, v in locals().items() if k in names}
72 76
kwargs.update(params)

adapt/parameter_based/_regular.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
3 3
"""
4 4

5 5
import numpy as np
6     -
from sklearn.exceptions import NotFittedError
7 6
from sklearn.preprocessing import LabelBinarizer
8 7
from scipy.sparse.linalg import lsqr
9 8
import tensorflow as tf
@@ -14,7 +13,8 @@
14 13
from adapt.utils import (check_arrays,
15 14
set_random_seed,
16 15
check_estimator,
17     -
check_network)
    16 +
check_network,
    17 +
check_fitted_estimator)
18 18

19 19

20 20
@make_insert_doc(supervised=True)
@@ -101,9 +101,12 @@ def __init__(self,
101 101
**params):
102 102

103 103
if not hasattr(estimator, "coef_"):
104     -
raise NotFittedError("`estimator` argument has no ``coef_`` attribute, "
105     -
"please call `fit` on `estimator` or use "
106     -
"another estimator.")
    104 +
raise ValueError("`estimator` argument has no ``coef_`` attribute, "
    105 +
"please call `fit` on `estimator` or use "
    106 +
"another estimator as `LinearRegression` or "
    107 +
"`RidgeClassifier`.")
    108 +
    109 +
estimator = check_fitted_estimator(estimator)
107 110

108 111
names = self._get_param_names()
109 112
kwargs = {k: v for k, v in locals().items() if k in names}
@@ -137,7 +140,7 @@ def fit(self, Xt=None, yt=None, **fit_params):
137 140
self.estimator_ = check_estimator(self.estimator,
138 141
copy=self.copy,
139 142
force_copy=True)
140     -
    143 +
141 144
if self.estimator_.fit_intercept:
142 145
intercept_ = np.reshape(
143 146
self.estimator_.intercept_,

0 commit comments

Comments (0)