
Commit 058c9d0

Merge: 2 parents ed1b107 + c69d02a


10 files changed: +181 -135 lines changed


adapt/base.py

Lines changed: 148 additions & 73 deletions
@@ -367,7 +367,7 @@ def _save_validation_data(self, Xs, Xt):
         else:
             self.Xs_ = Xs
             self.Xt_ = Xt
-            self.src_index_ = np.arange(len(Xs))
+            self.src_index_ = None
 
 
     def _get_target_data(self, X, y):
@@ -857,7 +857,7 @@ def __init__(self,
         self._self_setattr_tracking = True
 
 
-    def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
+    def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
         """
         Fit Model. Note that ``fit`` does not reset
         the model but extend the training.
@@ -867,7 +867,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         X : array or Tensor
             Source input data.
 
-        y : array or Tensor
+        y : array or Tensor (default=None)
             Source output data.
 
         Xt : array (default=None)
@@ -896,64 +896,120 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         if not hasattr(self, "_is_fitted"):
             self._is_fitted = True
             self._initialize_networks()
-            self._initialize_weights(X.shape[1:])
+            if isinstance(X, tf.data.Dataset):
+                first_elem = next(iter(X))
+                if (not isinstance(first_elem, tuple) or
+                    not len(first_elem)==2):
+                    raise ValueError("When first argument is a dataset. "
+                                     "It should return (x, y) tuples.")
+                else:
+                    shape = first_elem[0].shape
+            else:
+                shape = X.shape[1:]
+            print(shape)
+            self._initialize_weights(shape)
+
+        # 2. Get Fit params
+        fit_params = self._filter_params(super().fit, fit_params)
+
+        verbose = fit_params.get("verbose", 1)
+        epochs = fit_params.get("epochs", 1)
+        batch_size = fit_params.pop("batch_size", 32)
+        shuffle = fit_params.pop("shuffle", True)
+        validation_data = fit_params.pop("validation_data", None)
+        validation_split = fit_params.pop("validation_split", 0.)
+        validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
+
+        # 3. Prepare datasets
 
-        # 2. Prepare dataset
+        ### 3.1 Source
+        if not isinstance(X, tf.data.Dataset):
+            check_arrays(X, y)
+            if len(y.shape) <= 1:
+                y = y.reshape(-1, 1)
+
+            # Single source
+            if domains is None:
+                self.n_sources_ = 1
+
+                dataset_Xs = tf.data.Dataset.from_tensor_slices(X)
+                dataset_ys = tf.data.Dataset.from_tensor_slices(y)
+
+            # Multisource
+            else:
+                domains = self._check_domains(domains)
+                self.n_sources_ = int(np.max(domains)+1)
+
+                sizes = [np.sum(domains==dom)
+                         for dom in range(self.n_sources_)]
+
+                max_size = np.max(sizes)
+                repeats = np.ceil(max_size/sizes)
+
+                dataset_Xs = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(X[domains==dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+                dataset_ys = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(y[domains==dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
+
+        else:
+            dataset_src = X
+
+        ### 3.2 Target
         Xt, yt = self._get_target_data(Xt, yt)
+        if not isinstance(Xt, tf.data.Dataset):
+            if yt is None:
+                check_array(Xt, ensure_2d=True, allow_nd=True)
+                dataset_tgt = tf.data.Dataset.from_tensor_slices(Xt)
 
-        check_arrays(X, y)
-        if len(y.shape) <= 1:
-            y = y.reshape(-1, 1)
+            else:
+                check_arrays(Xt, yt)
+
+                if len(yt.shape) <= 1:
+                    yt = yt.reshape(-1, 1)
+
+                dataset_Xt = tf.data.Dataset.from_tensor_slices(Xt)
+                dataset_yt = tf.data.Dataset.from_tensor_slices(yt)
+                dataset_tgt = tf.data.Dataset.zip((dataset_Xt, dataset_yt))
 
-        if yt is None:
-            yt = y
-            check_array(Xt, ensure_2d=True, allow_nd=True)
         else:
-            check_arrays(Xt, yt)
-
-            if len(yt.shape) <= 1:
-                yt = yt.reshape(-1, 1)
+            dataset_tgt = Xt
 
         self._save_validation_data(X, Xt)
 
-        domains = fit_params.pop("domains", None)
-
-        if domains is None:
-            domains = np.zeros(len(X))
-
-        domains = self._check_domains(domains)
-
-        self.n_sources_ = int(np.max(domains)+1)
-
-        sizes = np.array(
-            [np.sum(domains==dom) for dom in range(self.n_sources_)]+
-            [len(Xt)])
-
-        max_size = np.max(sizes)
-        repeats = np.ceil(max_size/sizes)
-
-        dataset_X = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(X[domains==dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_))+
-            (tf.data.Dataset.from_tensor_slices(Xt).repeat(repeats[-1]),)
-        )
-
-        dataset_y = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(y[domains==dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_))+
-            (tf.data.Dataset.from_tensor_slices(yt).repeat(repeats[-1]),)
-        )
-
-
-        # 3. Get Fit params
-        fit_params = self._filter_params(super().fit, fit_params)
-
-        verbose = fit_params.get("verbose", 1)
-        epochs = fit_params.get("epochs", 1)
-        batch_size = fit_params.pop("batch_size", 32)
-        shuffle = fit_params.pop("shuffle", True)
+        # 4. Get validation data
+        validation_data = self._check_validation_data(validation_data,
+                                                      validation_batch_size,
+                                                      shuffle)
+
+        if validation_data is None and validation_split>0.:
+            if shuffle:
+                dataset_src = dataset_src.shuffle(buffer_size=1024)
+            frac = int(len(dataset_src)*validation_split)
+            validation_data = dataset_src.take(frac)
+            dataset_src = dataset_src.skip(frac)
+            validation_data = validation_data.batch(batch_size)
+
+        # 5. Set datasets
+        try:
+            max_size = max(len(dataset_src), len(dataset_tgt))
+            repeat_src = np.ceil(max_size/len(dataset_src))
+            repeat_tgt = np.ceil(max_size/len(dataset_tgt))
+
+            dataset_src = dataset_src.repeat(repeat_src)
+            dataset_tgt = dataset_tgt.repeat(repeat_tgt)
+
+            self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
+        except:
+            pass
 
-        # 4. Pretraining
+        # 5. Pretraining
         if not hasattr(self, "pretrain_"):
            if not hasattr(self, "pretrain"):
                self.pretrain_ = False
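
To make the reworked interface concrete, here is a minimal usage sketch assuming toy data and the default networks of a deep Adapt model (illustrative only, not part of the commit): fit now accepts either arrays plus an optional domains vector for multi-source training, or a tf.data.Dataset yielding (x, y) tuples, in which case y is left to None.

import numpy as np
import tensorflow as tf
from adapt.feature_based import DANN  # any deep Adapt model; DANN chosen for illustration

# Toy data (hypothetical)
Xs, ys = np.random.randn(100, 10), np.random.randn(100, 1)
Xt = np.random.randn(80, 10)
domains = np.array([0]*60 + [1]*40)   # two source domains

model = DANN()  # assuming the default encoder/task/discriminator networks

# 1) Arrays + domains: each source domain is repeated up to the size of the
#    largest one, then the per-domain datasets are zipped together.
model.fit(Xs, ys, Xt=Xt, domains=domains, epochs=1, verbose=0)

# 2) A tf.data.Dataset of (x, y) tuples is now accepted directly.
dataset = tf.data.Dataset.from_tensor_slices((Xs, ys))
model.fit(dataset, Xt=Xt, epochs=1, verbose=0)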
@@ -980,36 +1036,39 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
             pre_epochs = prefit_params.pop("epochs", epochs)
             pre_batch_size = prefit_params.pop("batch_size", batch_size)
             pre_shuffle = prefit_params.pop("shuffle", shuffle)
+            prefit_params.pop("validation_data", None)
+            prefit_params.pop("validation_split", None)
+            prefit_params.pop("validation_batch_size", None)
 
             if pre_shuffle:
-                dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(pre_batch_size)
+                dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(pre_batch_size)
             else:
-                dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(pre_batch_size)
+                dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(pre_batch_size)
 
-            hist = super().fit(dataset, epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
+            hist = super().fit(dataset, validation_data=validation_data,
+                               epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
 
             for k, v in hist.history.items():
                 self.pretrain_history_[k] = self.pretrain_history_.get(k, []) + v
 
             self._initialize_pretain_networks()
-
-        # 5. Training
+
+        # 6. Training
         if (not self._is_compiled) or (self.pretrain_):
             self.compile()
 
         if not hasattr(self, "history_"):
             self.history_ = {}
 
         if shuffle:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(batch_size)
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(batch_size)
         else:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(batch_size)
-
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(batch_size)
+
         self.pretrain_ = False
         self.steps_ = tf.Variable(0.)
-        self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
 
-        hist = super().fit(dataset, **fit_params)
+        hist = super().fit(dataset, validation_data=validation_data, **fit_params)
 
         for k, v in hist.history.items():
             self.history_[k] = self.history_.get(k, []) + v
@@ -1199,10 +1258,6 @@ def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
 
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         # Run forward pass.
         with tf.GradientTape() as tape:
             y_pred = self(Xs, training=True)
@@ -1390,6 +1445,22 @@ def score_estimator(self, X, y, sample_weight=None):
         if isinstance(score, (tuple, list)):
             score = score[0]
         return score
+
+
+    def _check_validation_data(self, validation_data, batch_size, shuffle):
+        if isinstance(validation_data, tuple):
+            X_val = validation_data[0]
+            y_val = validation_data[1]
+
+            validation_data = tf.data.Dataset.zip(
+                (tf.data.Dataset.from_tensor_slices(X_val),
+                 tf.data.Dataset.from_tensor_slices(y_val))
+            )
+            if shuffle:
+                validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
+            else:
+                validation_data = validation_data.batch(batch_size)
+        return validation_data
 
 
     def _get_legal_params(self, params):
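
Continuing the hypothetical sketch above, the two validation paths added in this commit would be driven as follows (argument values are illustrative only): a (X_val, y_val) tuple passed as validation_data is converted into a batched tf.data.Dataset by _check_validation_data before being forwarded to the Keras fit call, while validation_split carves a fraction off the source dataset when no explicit validation data is given.

# Hypothetical held-out source data
Xs_val, ys_val = np.random.randn(20, 10), np.random.randn(20, 1)

# a) Explicit validation set given as a tuple, converted and batched internally.
model.fit(Xs, ys, Xt=Xt,
          validation_data=(Xs_val, ys_val),
          validation_batch_size=64,
          epochs=1, verbose=0)

# b) No explicit validation set: a fraction of the source dataset is split off.
model.fit(Xs, ys, Xt=Xt,
          validation_split=0.1,
          epochs=1, verbose=0)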
@@ -1439,13 +1510,17 @@ def _initialize_weights(self, shape_X):
 
 
     def _unpack_data(self, data):
-        data_X = data[0]
-        data_y = data[1]
-        Xs = data_X[:-1]
-        Xt = data_X[-1]
-        ys = data_y[:-1]
-        yt = data_y[-1]
-        return Xs, Xt, ys, ys
+        data_src = data[0]
+        data_tgt = data[1]
+        Xs = data_src[0]
+        ys = data_src[1]
+        if isinstance(data_tgt, tuple):
+            Xt = data_tgt[0]
+            yt = data_tgt[1]
+            return Xs, Xt, ys, yt
+        else:
+            Xt = data_tgt
+            return Xs, Xt, ys, None
 
 
     def _get_disc_metrics(self, ys_disc, yt_disc):
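
Worth noting: the old _unpack_data returned ys twice (return Xs, Xt, ys, ys), which this rewrite fixes. A small standalone sketch of the batch layouts the new version expects (toy tensors, illustration only, not library code): each batch is a (source, target) pair whose source element is an (Xs, ys) tuple and whose target element is either a bare Xt tensor or an (Xt, yt) tuple.

import tensorflow as tf

Xs, ys = tf.zeros((32, 10)), tf.zeros((32, 1))
Xt, yt = tf.zeros((32, 10)), tf.zeros((32, 1))

def unpack(data):
    # Mirrors the new _unpack_data logic, for illustration only.
    (Xs, ys), data_tgt = data
    if isinstance(data_tgt, tuple):
        return Xs, data_tgt[0], ys, data_tgt[1]
    return Xs, data_tgt, ys, None

assert unpack(((Xs, ys), Xt))[3] is None        # unlabeled target -> yt is None
assert unpack(((Xs, ys), (Xt, yt)))[3] is yt    # labeled target -> yt is returned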

adapt/feature_based/_adda.py

Lines changed: 1 addition & 9 deletions
@@ -168,10 +168,6 @@ def pretrain_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
 
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         # loss
         with tf.GradientTape() as tape:
             # Forward pass
@@ -208,11 +204,7 @@ def train_step(self, data):
         else:
             # Unpack the data.
             Xs, Xt, ys, yt = self._unpack_data(data)
-
-            # Single source
-            Xs = Xs[0]
-            ys = ys[0]
-
+
             # loss
             with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape, tf.GradientTape() as disc_tape:
                 # Forward pass

adapt/feature_based/_cdan.py

Lines changed: 0 additions & 4 deletions
@@ -169,10 +169,6 @@ def __init__(self,
     def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
-
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
 
         # loss
         with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape, tf.GradientTape() as disc_tape:

adapt/feature_based/_dann.py

Lines changed: 0 additions & 4 deletions
@@ -138,10 +138,6 @@ def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
 
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         if self.lambda_ is None:
             _is_lambda_None = 1.
             lambda_ = 0.

adapt/feature_based/_deepcoral.py

Lines changed: 0 additions & 4 deletions
@@ -133,10 +133,6 @@ def __init__(self,
     def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
-
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
 
         if self.match_mean:
             _match_mean = 1.

adapt/feature_based/_mcd.py

Lines changed: 0 additions & 8 deletions
@@ -89,10 +89,6 @@ def __init__(self,
     def pretrain_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
-
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
 
         # loss
         with tf.GradientTape() as tape:
@@ -136,10 +132,6 @@ def train_step(self, data):
         else:
             # Unpack the data.
             Xs, Xt, ys, yt = self._unpack_data(data)
-
-            # Single source
-            Xs = Xs[0]
-            ys = ys[0]
 
 
             for _ in range(4):

adapt/feature_based/_mdd.py

Lines changed: 0 additions & 4 deletions
@@ -88,10 +88,6 @@ def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
 
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         # If crossentropy take argmax of preds
         if hasattr(self.task_loss_, "name"):
             name = self.task_loss_.name
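
All of the feature_based edits above follow from the base change: _unpack_data now yields Xs and ys as plain tensors rather than per-domain tuples, so the Xs = Xs[0] / ys = ys[0] indexing can be dropped from every train_step and pretrain_step. A simplified, hypothetical train_step consuming the new layout is sketched below (not the library's actual implementation; the target tensor and adaptation loss are omitted).

import tensorflow as tf

class SketchModel(tf.keras.Model):
    # Illustration only: a Keras model whose train_step reads ((Xs, ys), Xt) batches.
    def train_step(self, data):
        (Xs, ys), Xt = data   # Xs and ys are tensors; no Xs[0]/ys[0] indexing needed
        with tf.GradientTape() as tape:
            y_pred = self(Xs, training=True)
            loss = self.compiled_loss(ys, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}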
