Skip to content

Commit 7d40345

Browse files
Update Transfer Tree + Fix cloning for parameter based
1 parent 049e239 commit 7d40345

20 files changed

+595
-262
lines changed

adapt/parameter_based/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,10 @@
44

55
from ._regular import RegularTransferLR, RegularTransferLC, RegularTransferNN
66
from ._finetuning import FineTuning
7+
from ._transfer_tree import TransferTreeClassifier
78

8-
__all__ = ["RegularTransferLR", "RegularTransferLC", "RegularTransferNN", "FineTuning"]
9+
__all__ = ["RegularTransferLR",
10+
"RegularTransferLC",
11+
"RegularTransferNN",
12+
"FineTuning",
13+
"TransferTreeClassifier"]

adapt/parameter_based/_finetuning.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import tensorflow as tf
22
from adapt.base import BaseAdaptDeep, make_insert_doc
3+
from adapt.utils import check_fitted_network
34

45

56
@make_insert_doc(["encoder", "task"], supervised=True)
@@ -67,6 +68,9 @@ def __init__(self,
6768
random_state=None,
6869
**params):
6970

71+
encoder = check_fitted_network(encoder)
72+
task = check_fitted_network(task)
73+
7074
names = self._get_param_names()
7175
kwargs = {k: v for k, v in locals().items() if k in names}
7276
kwargs.update(params)

adapt/parameter_based/_regular.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44

55
import numpy as np
6-
from sklearn.exceptions import NotFittedError
76
from sklearn.preprocessing import LabelBinarizer
87
from scipy.sparse.linalg import lsqr
98
import tensorflow as tf
@@ -14,7 +13,8 @@
1413
from adapt.utils import (check_arrays,
1514
set_random_seed,
1615
check_estimator,
17-
check_network)
16+
check_network,
17+
check_fitted_estimator)
1818

1919

2020
@make_insert_doc(supervised=True)
@@ -101,9 +101,12 @@ def __init__(self,
101101
**params):
102102

103103
if not hasattr(estimator, "coef_"):
104-
raise NotFittedError("`estimator` argument has no ``coef_`` attribute, "
105-
"please call `fit` on `estimator` or use "
106-
"another estimator.")
104+
raise ValueError("`estimator` argument has no ``coef_`` attribute, "
105+
"please call `fit` on `estimator` or use "
106+
"another estimator as `LinearRegression` or "
107+
"`RidgeClassifier`.")
108+
109+
estimator = check_fitted_estimator(estimator)
107110

108111
names = self._get_param_names()
109112
kwargs = {k: v for k, v in locals().items() if k in names}
@@ -137,7 +140,7 @@ def fit(self, Xt=None, yt=None, **fit_params):
137140
self.estimator_ = check_estimator(self.estimator,
138141
copy=self.copy,
139142
force_copy=True)
140-
143+
141144
if self.estimator_.fit_intercept:
142145
intercept_ = np.reshape(
143146
self.estimator_.intercept_,

0 commit comments

Comments
 (0)