Skip to content

Commit 801789e

Browse files
Merge pull request #19 from antoinedemathelin/master
fix: Fix bugs Deep methods
2 parents f0e927d + 3581099 commit 801789e

File tree

16 files changed

+500
-244
lines changed

16 files changed

+500
-244
lines changed

adapt/base.py

Lines changed: 158 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def _save_validation_data(self, Xs, Xt):
367367
else:
368368
self.Xs_ = Xs
369369
self.Xt_ = Xt
370-
self.src_index_ = np.arange(len(Xs))
370+
self.src_index_ = None
371371

372372

373373
def _get_target_data(self, X, y):
@@ -458,7 +458,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
458458
if yt is not None:
459459
Xt, yt = check_arrays(Xt, yt)
460460
else:
461-
Xt = check_array(Xt)
461+
Xt = check_array(Xt, ensure_2d=True, allow_nd=True)
462462
set_random_seed(self.random_state)
463463

464464
self._save_validation_data(X, Xt)
@@ -857,7 +857,7 @@ def __init__(self,
857857
self._self_setattr_tracking = True
858858

859859

860-
def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
860+
def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
861861
"""
862862
Fit Model. Note that ``fit`` does not reset
863863
the model but extends the training.
@@ -867,7 +867,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
867867
X : array or Tensor
868868
Source input data.
869869
870-
y : array or Tensor
870+
y : array or Tensor (default=None)
871871
Source output data.
872872
873873
Xt : array (default=None)
@@ -889,71 +889,126 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
889889
Returns
890890
-------
891891
self : returns an instance of self
892-
"""
892+
"""
893893
set_random_seed(self.random_state)
894894

895895
# 1. Initialize networks
896896
if not hasattr(self, "_is_fitted"):
897897
self._is_fitted = True
898898
self._initialize_networks()
899-
self._initialize_weights(X.shape[1:])
899+
if isinstance(X, tf.data.Dataset):
900+
first_elem = next(iter(X))
901+
if (not isinstance(first_elem, tuple) or
902+
not len(first_elem)==2):
903+
raise ValueError("When first argument is a dataset. "
904+
"It should return (x, y) tuples.")
905+
else:
906+
shape = first_elem[0].shape
907+
else:
908+
shape = X.shape[1:]
909+
self._initialize_weights(shape)
910+
911+
# 2. Get Fit params
912+
fit_params = self._filter_params(super().fit, fit_params)
900913

901-
# 2. Prepare dataset
914+
verbose = fit_params.get("verbose", 1)
915+
epochs = fit_params.get("epochs", 1)
916+
batch_size = fit_params.pop("batch_size", 32)
917+
shuffle = fit_params.pop("shuffle", True)
918+
validation_data = fit_params.pop("validation_data", None)
919+
validation_split = fit_params.pop("validation_split", 0.)
920+
validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
921+
922+
# 3. Prepare datasets
923+
924+
### 3.1 Source
925+
if not isinstance(X, tf.data.Dataset):
926+
check_arrays(X, y)
927+
if len(y.shape) <= 1:
928+
y = y.reshape(-1, 1)
929+
930+
# Single source
931+
if domains is None:
932+
self.n_sources_ = 1
933+
934+
dataset_Xs = tf.data.Dataset.from_tensor_slices(X)
935+
dataset_ys = tf.data.Dataset.from_tensor_slices(y)
936+
937+
# Multisource
938+
else:
939+
domains = self._check_domains(domains)
940+
self.n_sources_ = int(np.max(domains)+1)
941+
942+
sizes = [np.sum(domains==dom)
943+
for dom in range(self.n_sources_)]
944+
945+
max_size = np.max(sizes)
946+
repeats = np.ceil(max_size/sizes)
947+
948+
dataset_Xs = tf.data.Dataset.zip(tuple(
949+
tf.data.Dataset.from_tensor_slices(X[domains==dom]).repeat(repeats[dom])
950+
for dom in range(self.n_sources_))
951+
)
952+
953+
dataset_ys = tf.data.Dataset.zip(tuple(
954+
tf.data.Dataset.from_tensor_slices(y[domains==dom]).repeat(repeats[dom])
955+
for dom in range(self.n_sources_))
956+
)
957+
958+
dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
959+
960+
else:
961+
dataset_src = X
962+
963+
### 3.2 Target
902964
Xt, yt = self._get_target_data(Xt, yt)
965+
if not isinstance(Xt, tf.data.Dataset):
966+
if yt is None:
967+
check_array(Xt, ensure_2d=True, allow_nd=True)
968+
dataset_tgt = tf.data.Dataset.from_tensor_slices(Xt)
903969

904-
check_arrays(X, y)
905-
if len(y.shape) <= 1:
906-
y = y.reshape(-1, 1)
970+
else:
971+
check_arrays(Xt, yt)
972+
973+
if len(yt.shape) <= 1:
974+
yt = yt.reshape(-1, 1)
975+
976+
dataset_Xt = tf.data.Dataset.from_tensor_slices(Xt)
977+
dataset_yt = tf.data.Dataset.from_tensor_slices(yt)
978+
dataset_tgt = tf.data.Dataset.zip((dataset_Xt, dataset_yt))
907979

908-
if yt is None:
909-
yt = y
910-
check_array(Xt)
911980
else:
912-
check_arrays(Xt, yt)
913-
914-
if len(yt.shape) <= 1:
915-
yt = yt.reshape(-1, 1)
981+
dataset_tgt = Xt
916982

917983
self._save_validation_data(X, Xt)
918984

919-
domains = fit_params.pop("domains", None)
920-
921-
if domains is None:
922-
domains = np.zeros(len(X))
923-
924-
domains = self._check_domains(domains)
925-
926-
self.n_sources_ = int(np.max(domains)+1)
927-
928-
sizes = np.array(
929-
[np.sum(domains==dom) for dom in range(self.n_sources_)]+
930-
[len(Xt)])
931-
932-
max_size = np.max(sizes)
933-
repeats = np.ceil(max_size/sizes)
934-
935-
dataset_X = tf.data.Dataset.zip(tuple(
936-
tf.data.Dataset.from_tensor_slices(X[domains==dom]).repeat(repeats[dom])
937-
for dom in range(self.n_sources_))+
938-
(tf.data.Dataset.from_tensor_slices(Xt).repeat(repeats[-1]),)
939-
)
940-
941-
dataset_y = tf.data.Dataset.zip(tuple(
942-
tf.data.Dataset.from_tensor_slices(y[domains==dom]).repeat(repeats[dom])
943-
for dom in range(self.n_sources_))+
944-
(tf.data.Dataset.from_tensor_slices(yt).repeat(repeats[-1]),)
945-
)
946-
947-
948-
# 3. Get Fit params
949-
fit_params = self._filter_params(super().fit, fit_params)
950-
951-
verbose = fit_params.get("verbose", 1)
952-
epochs = fit_params.get("epochs", 1)
953-
batch_size = fit_params.pop("batch_size", 32)
954-
shuffle = fit_params.pop("shuffle", True)
985+
# 4. Get validation data
986+
# validation_data = self._check_validation_data(validation_data,
987+
# validation_batch_size,
988+
# shuffle)
989+
990+
if validation_data is None and validation_split>0.:
991+
if shuffle:
992+
dataset_src = dataset_src.shuffle(buffer_size=1024)
993+
frac = int(len(dataset_src)*validation_split)
994+
validation_data = dataset_src.take(frac)
995+
dataset_src = dataset_src.skip(frac)
996+
validation_data = validation_data.batch(batch_size)
997+
998+
# 5. Set datasets
999+
try:
1000+
max_size = max(len(dataset_src), len(dataset_tgt))
1001+
repeat_src = np.ceil(max_size/len(dataset_src))
1002+
repeat_tgt = np.ceil(max_size/len(dataset_tgt))
1003+
1004+
dataset_src = dataset_src.repeat(repeat_src)
1005+
dataset_tgt = dataset_tgt.repeat(repeat_tgt)
1006+
1007+
self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
1008+
except:
1009+
pass
9551010

956-
# 4. Pretraining
1011+
# 5. Pretraining
9571012
if not hasattr(self, "pretrain_"):
9581013
if not hasattr(self, "pretrain"):
9591014
self.pretrain_ = False
@@ -980,36 +1035,39 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
9801035
pre_epochs = prefit_params.pop("epochs", epochs)
9811036
pre_batch_size = prefit_params.pop("batch_size", batch_size)
9821037
pre_shuffle = prefit_params.pop("shuffle", shuffle)
1038+
prefit_params.pop("validation_data", None)
1039+
prefit_params.pop("validation_split", None)
1040+
prefit_params.pop("validation_batch_size", None)
9831041

9841042
if pre_shuffle:
985-
dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(pre_batch_size)
1043+
dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(pre_batch_size)
9861044
else:
987-
dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(pre_batch_size)
1045+
dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(pre_batch_size)
9881046

989-
hist = super().fit(dataset, epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
1047+
hist = super().fit(dataset, validation_data=validation_data,
1048+
epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
9901049

9911050
for k, v in hist.history.items():
9921051
self.pretrain_history_[k] = self.pretrain_history_.get(k, []) + v
9931052

9941053
self._initialize_pretain_networks()
995-
996-
# 5. Training
1054+
1055+
# 6. Compile
9971056
if (not self._is_compiled) or (self.pretrain_):
9981057
self.compile()
9991058

10001059
if not hasattr(self, "history_"):
10011060
self.history_ = {}
10021061

1062+
# 7. Training
10031063
if shuffle:
1004-
dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(batch_size)
1064+
dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(batch_size)
10051065
else:
1006-
dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(batch_size)
1007-
1066+
dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(batch_size)
1067+
10081068
self.pretrain_ = False
1009-
self.steps_ = tf.Variable(0.)
1010-
self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
10111069

1012-
hist = super().fit(dataset, **fit_params)
1070+
hist = super().fit(dataset, validation_data=validation_data, **fit_params)
10131071

10141072
for k, v in hist.history.items():
10151073
self.history_[k] = self.history_.get(k, []) + v
@@ -1188,6 +1246,12 @@ def compile(self,
11881246
super().compile(
11891247
**compile_params
11901248
)
1249+
1250+
# Set optimizer for encoder and discriminator
1251+
if not hasattr(self, "optimizer_enc"):
1252+
self.optimizer_enc = self.optimizer
1253+
if not hasattr(self, "optimizer_disc"):
1254+
self.optimizer_disc = self.optimizer
11911255

11921256

11931257
def call(self, inputs):
@@ -1199,10 +1263,6 @@ def train_step(self, data):
11991263
# Unpack the data.
12001264
Xs, Xt, ys, yt = self._unpack_data(data)
12011265

1202-
# Single source
1203-
Xs = Xs[0]
1204-
ys = ys[0]
1205-
12061266
# Run forward pass.
12071267
with tf.GradientTape() as tape:
12081268
y_pred = self(Xs, training=True)
@@ -1376,7 +1436,7 @@ def score_estimator(self, X, y, sample_weight=None):
13761436
score : float
13771437
Score.
13781438
"""
1379-
if np.prod(X.shape) <= 10**8:
1439+
if hasattr(X, "shape") and np.prod(X.shape) <= 10**8:
13801440
score = self.evaluate(
13811441
X, y,
13821442
sample_weight=sample_weight,
@@ -1390,6 +1450,22 @@ def score_estimator(self, X, y, sample_weight=None):
13901450
if isinstance(score, (tuple, list)):
13911451
score = score[0]
13921452
return score
1453+
1454+
1455+
# def _check_validation_data(self, validation_data, batch_size, shuffle):
1456+
# if isinstance(validation_data, tuple):
1457+
# X_val = validation_data[0]
1458+
# y_val = validation_data[1]
1459+
1460+
# validation_data = tf.data.Dataset.zip(
1461+
# (tf.data.Dataset.from_tensor_slices(X_val),
1462+
# tf.data.Dataset.from_tensor_slices(y_val))
1463+
# )
1464+
# if shuffle:
1465+
# validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
1466+
# else:
1467+
# validation_data = validation_data.batch(batch_size)
1468+
# return validation_data
13931469

13941470

13951471
def _get_legal_params(self, params):
@@ -1405,7 +1481,7 @@ def _get_legal_params(self, params):
14051481
if (optimizer is not None) and (not isinstance(optimizer, str)):
14061482
legal_params_fct.append(optimizer.__init__)
14071483

1408-
legal_params = ["domain", "val_sample_size"]
1484+
legal_params = ["domain", "val_sample_size", "optimizer_enc", "optimizer_disc"]
14091485
for func in legal_params_fct:
14101486
args = [
14111487
p.name
@@ -1439,13 +1515,17 @@ def _initialize_weights(self, shape_X):
14391515

14401516

14411517
def _unpack_data(self, data):
1442-
data_X = data[0]
1443-
data_y = data[1]
1444-
Xs = data_X[:-1]
1445-
Xt = data_X[-1]
1446-
ys = data_y[:-1]
1447-
yt = data_y[-1]
1448-
return Xs, Xt, ys, ys
1518+
data_src = data[0]
1519+
data_tgt = data[1]
1520+
Xs = data_src[0]
1521+
ys = data_src[1]
1522+
if isinstance(data_tgt, tuple):
1523+
Xt = data_tgt[0]
1524+
yt = data_tgt[1]
1525+
return Xs, Xt, ys, yt
1526+
else:
1527+
Xt = data_tgt
1528+
return Xs, Xt, ys, None
14491529

14501530

14511531
def _get_disc_metrics(self, ys_disc, yt_disc):

0 commit comments

Comments
 (0)