Skip to content

Commit 0c7632a

Browse files
Fix bug inspect signature
1 parent e665d78 commit 0c7632a

File tree

9 files changed

+730
-292
lines changed

9 files changed

+730
-292
lines changed

adapt/base.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,11 @@ def score_adapt(self, Xs, Xt, src_index=None):
265265
if src_index is None:
266266
src_index = np.arange(len(Xs))
267267
if hasattr(self, "transform"):
268-
args = inspect.getfullargspec(self.transform).args
268+
args = [
269+
p.name
270+
for p in inspect.signature(self.transform).parameters.values()
271+
if p.name != "self" and p.kind != p.VAR_KEYWORD
272+
]
269273
if "domain" in args:
270274
Xt = self.transform(Xt, domain="tgt")
271275
Xs = self.transform(Xs, domain="src")
@@ -333,7 +337,11 @@ def _get_param_names(self):
333337

334338
def _filter_params(self, func, override={}, prefix=""):
335339
kwargs = {}
336-
args = inspect.getfullargspec(func).args
340+
args = [
341+
p.name
342+
for p in inspect.signature(func).parameters.values()
343+
if p.name != "self" and p.kind != p.VAR_KEYWORD
344+
]
337345
for key, value in self.__dict__.items():
338346
new_key = key.replace(prefix+"__", "")
339347
if new_key in args and prefix in key:
@@ -575,7 +583,12 @@ def fit_estimator(self, X, y, sample_weight=None,
575583

576584
fit_params = self._filter_params(self.estimator_.fit, fit_params)
577585

578-
if "sample_weight" in inspect.getfullargspec(self.estimator_.fit).args:
586+
fit_args = [
587+
p.name
588+
for p in inspect.signature(self.estimator_.fit).parameters.values()
589+
if p.name != "self" and p.kind != p.VAR_KEYWORD
590+
]
591+
if "sample_weight" in fit_args:
579592
sample_weight = check_sample_weight(sample_weight, X)
580593
with warnings.catch_warnings():
581594
warnings.simplefilter("ignore")
@@ -718,7 +731,11 @@ def _get_legal_params(self, params):
718731

719732
legal_params = ["domain", "val_sample_size"]
720733
for func in legal_params_fct:
721-
args = list(inspect.getfullargspec(func).args)
734+
args = [
735+
p.name
736+
for p in inspect.signature(func).parameters.values()
737+
if p.name != "self" and p.kind != p.VAR_KEYWORD
738+
]
722739
legal_params = legal_params + args
723740

724741
# Add kernel params for kernel based algorithm
@@ -1190,7 +1207,11 @@ def _get_legal_params(self, params):
11901207

11911208
legal_params = ["domain", "val_sample_size"]
11921209
for func in legal_params_fct:
1193-
args = list(inspect.getfullargspec(func).args)
1210+
args = [
1211+
p.name
1212+
for p in inspect.signature(func).parameters.values()
1213+
if p.name != "self" and p.kind != p.VAR_KEYWORD
1214+
]
11941215
legal_params = legal_params + args
11951216

11961217
if "pretrain" in legal_params:
@@ -1200,7 +1221,11 @@ def _get_legal_params(self, params):
12001221
legal_params_fct.append(params["pretrain__optimizer"].__init__)
12011222

12021223
for func in legal_params_fct:
1203-
args = list(inspect.getfullargspec(func).args)
1224+
args = [
1225+
p.name
1226+
for p in inspect.signature(func).parameters.values()
1227+
if p.name != "self" and p.kind != p.VAR_KEYWORD
1228+
]
12041229
legal_params = legal_params + ["pretrain__"+name for name in args]
12051230
return legal_params
12061231

adapt/feature_based/_adda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def _initialize_weights(self, shape_X):
288288
self.encoder_(np.zeros((1,) + shape_X))
289289

290290
# Set same weights to encoder_src
291-
self.encoder_src_ = check_network(self.encoder,
291+
self.encoder_src_ = check_network(self.encoder_,
292292
copy=True,
293293
name="encoder_src")
294294

adapt/feature_based/_cdan.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,6 @@ def train_step(self, data):
248248
disc_loss += sum(self.discriminator_.losses)
249249
enc_loss += sum(self.encoder_.losses)
250250

251-
print(task_loss.shape, enc_loss.shape, disc_loss.shape)
252-
253251
# Compute gradients
254252
trainable_vars_task = self.task_.trainable_variables
255253
trainable_vars_enc = self.encoder_.trainable_variables
@@ -295,8 +293,10 @@ def _initialize_weights(self, shape_X):
295293
self.max_features])
296294
self._random_enc = tf.random.normal([Xs_enc.get_shape()[1],
297295
self.max_features])
296+
self.discriminator_(np.zeros((1, self.max_features)))
298297
else:
299298
self.is_overloaded_ = False
299+
self.discriminator_(np.zeros((1, Xs_enc.get_shape()[1] * ys_pred.get_shape()[1])))
300300

301301

302302
def _initialize_networks(self):
@@ -318,6 +318,7 @@ def _initialize_networks(self):
318318
self.discriminator_ = check_network(self.discriminator,
319319
copy=self.copy,
320320
name="discriminator")
321+
321322

322323

323324
# def _initialize_networks(self, shape_Xt):

adapt/feature_based/_mcd.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import tensorflow as tf
77

88
from adapt.base import BaseAdaptDeep, make_insert_doc
9-
from adapt.utils import check_network
9+
from adapt.utils import check_network, get_default_encoder, get_default_task
1010

1111
EPS = np.finfo(np.float32).eps
1212

@@ -27,10 +27,10 @@ class MCD(BaseAdaptDeep):
2727
2828
Parameters
2929
----------
30-
pretrain_steps : int (default=None)
31-
Specify the number of pretraining of source encoder
32-
and task networks. If `None` the number of pretrain
33-
steps is equal to half the total number of training steps.
30+
pretrain : bool (default=True)
31+
Weither to pretrain the networks or not.
32+
If True, the three networks are fitted on source
33+
labeled data.
3434
3535
Attributes
3636
----------

adapt/feature_based/_mdd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import tensorflow as tf
77

88
from adapt.base import BaseAdaptDeep, make_insert_doc
9-
from adapt.utils import check_network
9+
from adapt.utils import check_network, get_default_encoder, get_default_task
1010

1111
EPS = np.finfo(np.float32).eps
1212

adapt/feature_based/_wdgrl.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,6 @@ class WDGRL(BaseAdaptDeep):
7171
7272
discriminator_ : tensorflow Model
7373
discriminator network.
74-
75-
model_ : tensorflow Model
76-
Fitted model: the union of ``encoder_``,
77-
``task_`` and ``discriminator_`` networks.
7874
7975
history_ : dict
8076
history of the losses and metrics across the epochs.

adapt/instance_based/_wann.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@ class WANN(BaseAdaptDeep):
2828
between the reweighted source and target distributions: the Y-discrepancy
2929
3030
Parameters
31-
----------
31+
----------
32+
pretrain : bool (default=True)
33+
Weither to perform pretraining of the ``weighter``
34+
network or not. If True, the ``weighter`` is
35+
pretrained in order to predict 1 for each source.
36+
3237
C : float (default=1.)
3338
Clipping constant for the weighting networks
3439
regularization. Low value of ``C`` produce smoother
@@ -40,6 +45,7 @@ def __init__(self,
4045
weighter=None,
4146
Xt=None,
4247
yt=None,
48+
pretrain=True,
4349
C=1.,
4450
verbose=1,
4551
copy=True,

src_docs/examples/Two_moons.ipynb

Lines changed: 675 additions & 265 deletions
Large diffs are not rendered by default.

tests/test_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sklearn.compose import TransformedTargetRegressor
1717
from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin
1818
from sklearn.tree._tree import Tree
19-
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
19+
# from tensorflow.keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
2020
from tensorflow.keras import Model, Sequential
2121
from tensorflow.keras.layers import Input, Dense, Flatten, Reshape
2222
from tensorflow.python.keras.engine.input_layer import InputLayer
@@ -28,7 +28,7 @@ def is_equal_estimator(v1, v2):
2828
assert type(v2) == type(v1)
2929
if isinstance(v1, np.ndarray):
3030
assert np.array_equal(v1, v2)
31-
elif isinstance(v1, (BaseEstimator, KerasClassifier, KerasRegressor)):
31+
elif isinstance(v1, BaseEstimator): # KerasClassifier, KerasRegressor
3232
assert is_equal_estimator(v1.__dict__, v2.__dict__)
3333
elif isinstance(v1, Model):
3434
assert is_equal_estimator(v1.get_config(),
@@ -198,7 +198,7 @@ def test_check_network_high_dataset():
198198
TransformedTargetRegressor(regressor=Ridge(alpha=25), transformer=StandardScaler()),
199199
MultiOutputRegressor(Ridge(alpha=0.3)),
200200
make_pipeline(StandardScaler(), Ridge(alpha=0.2)),
201-
KerasClassifier(_get_model_Sequential, input_shape=(1,)),
201+
# KerasClassifier(_get_model_Sequential, input_shape=(1,)),
202202
CustomEstimator()
203203
]
204204

@@ -212,10 +212,10 @@ def test_check_estimator_estimators(est):
212212
else:
213213
est.fit(np.linspace(0, 1, 10).reshape(-1, 1),
214214
(np.linspace(0, 1, 10)<0.5).astype(float))
215-
if isinstance(est, KerasClassifier):
216-
new_est = check_estimator(est, copy=False)
217-
else:
218-
new_est = check_estimator(est, copy=True, force_copy=True)
215+
# if isinstance(est, KerasClassifier):
216+
# new_est = check_estimator(est, copy=False)
217+
# else:
218+
new_est = check_estimator(est, copy=True, force_copy=True)
219219
assert is_equal_estimator(est, new_est)
220220

221221

0 commit comments

Comments
 (0)