
Commit c69d02a

Add dataset handler base deep
1 parent 8185320 commit c69d02a

File tree: 10 files changed (+154 additions, -136 deletions)

adapt/base.py

Lines changed: 121 additions & 74 deletions
@@ -367,7 +367,7 @@ def _save_validation_data(self, Xs, Xt):
         else:
             self.Xs_ = Xs
             self.Xt_ = Xt
-            self.src_index_ = np.arange(len(Xs))
+            self.src_index_ = None
 
 
     def _get_target_data(self, X, y):
@@ -857,7 +857,7 @@ def __init__(self,
         self._self_setattr_tracking = True
 
 
-    def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
+    def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
         """
         Fit Model. Note that ``fit`` does not reset
         the model but extend the training.
@@ -867,7 +867,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         X : array or Tensor
             Source input data.
 
-        y : array or Tensor
+        y : array or Tensor (default=None)
             Source output data.
 
         Xt : array (default=None)
@@ -896,7 +896,18 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         if not hasattr(self, "_is_fitted"):
             self._is_fitted = True
             self._initialize_networks()
-            self._initialize_weights(X.shape[1:])
+            if isinstance(X, tf.data.Dataset):
+                first_elem = next(iter(X))
+                if (not isinstance(first_elem, tuple) or
+                    not len(first_elem)==2):
+                    raise ValueError("When first argument is a dataset. "
+                                     "It should return (x, y) tuples.")
+                else:
+                    shape = first_elem[0].shape
+            else:
+                shape = X.shape[1:]
+            print(shape)
+            self._initialize_weights(shape)
 
         # 2. Get Fit params
         fit_params = self._filter_params(super().fit, fit_params)
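When X is a tf.data.Dataset, the input shape used to build the networks is read from the dataset's first element, which must be an (x, y) pair. A minimal sketch of an input that satisfies this check, using hypothetical arrays X_arr and y_arr:

    import numpy as np
    import tensorflow as tf

    # Hypothetical arrays standing in for the source data.
    X_arr = np.random.randn(100, 16).astype("float32")
    y_arr = np.random.randn(100, 1).astype("float32")

    # from_tensor_slices on an (X, y) tuple yields (x, y) pairs,
    # which is the format the check above accepts.
    dataset = tf.data.Dataset.from_tensor_slices((X_arr, y_arr))

    first_elem = next(iter(dataset))
    print(isinstance(first_elem, tuple), len(first_elem))  # True 2
    print(first_elem[0].shape)                             # (16,) == X_arr.shape[1:]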
@@ -909,65 +920,96 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         validation_split = fit_params.pop("validation_split", 0.)
         validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
 
-        # 3. Prepare dataset
+        # 3. Prepare datasets
 
         ### 3.1 Source
         if not isinstance(X, tf.data.Dataset):
             check_arrays(X, y)
             if len(y.shape) <= 1:
                 y = y.reshape(-1, 1)
 
+            # Single source
+            if domains is None:
+                self.n_sources_ = 1
+
+                dataset_Xs = tf.data.Dataset.from_tensor_slices(X)
+                dataset_ys = tf.data.Dataset.from_tensor_slices(y)
+
+            # Multisource
+            else:
+                domains = self._check_domains(domains)
+                self.n_sources_ = int(np.max(domains)+1)
+
+                sizes = [np.sum(domains==dom)
+                         for dom in range(self.n_sources_)]
+
+                max_size = np.max(sizes)
+                repeats = np.ceil(max_size/sizes)
+
+                dataset_Xs = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(X[domains==dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+                dataset_ys = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(y[domains==dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
+
+        else:
+            dataset_src = X
+
         ### 3.2 Target
         Xt, yt = self._get_target_data(Xt, yt)
         if not isinstance(Xt, tf.data.Dataset):
             if yt is None:
-                yt = y
                 check_array(Xt, ensure_2d=True, allow_nd=True)
+                dataset_tgt = tf.data.Dataset.from_tensor_slices(Xt)
+
             else:
                 check_arrays(Xt, yt)
 
-            if len(yt.shape) <= 1:
-                yt = yt.reshape(-1, 1)
+                if len(yt.shape) <= 1:
+                    yt = yt.reshape(-1, 1)
+
+                dataset_Xt = tf.data.Dataset.from_tensor_slices(Xt)
+                dataset_yt = tf.data.Dataset.from_tensor_slices(yt)
+                dataset_tgt = tf.data.Dataset.zip((dataset_Xt, dataset_yt))
+
+        else:
+            dataset_tgt = Xt
 
         self._save_validation_data(X, Xt)
 
-        ### 3.3 Domains
-        domains = fit_params.pop("domains", None)
-
-        if domains is None:
-            domains = np.zeros(len(X))
-
-        domains = self._check_domains(domains)
-
-        self.n_sources_ = int(np.max(domains)+1)
-
-        sizes = np.array(
-            [np.sum(domains==dom) for dom in range(self.n_sources_)]+
-            [len(Xt)])
-
-        max_size = np.max(sizes)
-        repeats = np.ceil(max_size/sizes)
-
-        # Split if validation_split
-        # if validation_data is None and validation_split>0.:
-        #     frac = int(len(dataset)*validation_split)
-        #     validation_data = dataset.take(frac)
-        #     dataset = dataset.skip(frac)
-
+        # 4. Get validation data
+        validation_data = self._check_validation_data(validation_data,
+                                                      validation_batch_size,
+                                                      shuffle)
 
-        dataset_X = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(X[domains==dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_))+
-            (tf.data.Dataset.from_tensor_slices(Xt).repeat(repeats[-1]),)
-        )
-
-        dataset_y = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(y[domains==dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_))+
-            (tf.data.Dataset.from_tensor_slices(yt).repeat(repeats[-1]),)
-        )
+        if validation_data is None and validation_split>0.:
+            if shuffle:
+                dataset_src = dataset_src.shuffle(buffer_size=1024)
+            frac = int(len(dataset_src)*validation_split)
+            validation_data = dataset_src.take(frac)
+            dataset_src = dataset_src.skip(frac)
+            validation_data = validation_data.batch(batch_size)
+
+        # 5. Set datasets
+        try:
+            max_size = max(len(dataset_src), len(dataset_tgt))
+            repeat_src = np.ceil(max_size/len(dataset_src))
+            repeat_tgt = np.ceil(max_size/len(dataset_tgt))
+
+            dataset_src = dataset_src.repeat(repeat_src)
+            dataset_tgt = dataset_tgt.repeat(repeat_tgt)
+
+            self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
+        except:
+            pass
 
-        # 4. Pretraining
+        # 5. Pretraining
         if not hasattr(self, "pretrain_"):
             if not hasattr(self, "pretrain"):
                 self.pretrain_ = False
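The try/except in the "Set datasets" step balances source and target by repeating the shorter dataset so that zipping them later does not truncate either side (tf.data.Dataset.zip stops at the shortest input); the bare except presumably covers datasets whose cardinality is unknown, where len() raises. A standalone sketch of the same balancing idea, assuming two toy datasets of known length (the int cast on the repeat count is an extra safety not present above):

    import numpy as np
    import tensorflow as tf

    # Hypothetical toy datasets of unequal length.
    ds_src = tf.data.Dataset.from_tensor_slices(np.arange(10))
    ds_tgt = tf.data.Dataset.from_tensor_slices(np.arange(4))

    # Repeat the shorter dataset until it covers the longer one.
    max_size = max(len(ds_src), len(ds_tgt))
    ds_src = ds_src.repeat(int(np.ceil(max_size / len(ds_src))))
    ds_tgt = ds_tgt.repeat(int(np.ceil(max_size / len(ds_tgt))))

    # zip now yields max_size usable pairs instead of only 4.
    pairs = tf.data.Dataset.zip((ds_src, ds_tgt)).take(max_size)
    print(len(list(pairs)))  # 10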
@@ -994,32 +1036,22 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
             pre_epochs = prefit_params.pop("epochs", epochs)
             pre_batch_size = prefit_params.pop("batch_size", batch_size)
             pre_shuffle = prefit_params.pop("shuffle", shuffle)
+            prefit_params.pop("validation_data", None)
+            prefit_params.pop("validation_split", None)
+            prefit_params.pop("validation_batch_size", None)
 
             if pre_shuffle:
-                dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(pre_batch_size)
+                dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(pre_batch_size)
             else:
-                dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(pre_batch_size)
+                dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(pre_batch_size)
 
-            hist = super().fit(dataset, epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
+            hist = super().fit(dataset, validation_data=validation_data,
+                               epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
 
             for k, v in hist.history.items():
                 self.pretrain_history_[k] = self.pretrain_history_.get(k, []) + v
 
             self._initialize_pretain_networks()
-
-        # 5. Define validation Set
-        if isinstance(validation_data, tuple):
-            X_val = validation_data[0]
-            y_val = validation_data[1]
-
-            validation_data = tf.data.Dataset.zip(
-                (tf.data.Dataset.from_tensor_slices(X_val),
-                 tf.data.Dataset.from_tensor_slices(y_val))
-            )
-            if shuffle:
-                validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
-            else:
-                validation_data = validation_data.batch(batch_size)
 
         # 6. Training
         if (not self._is_compiled) or (self.pretrain_):
@@ -1029,13 +1061,12 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
             self.history_ = {}
 
         if shuffle:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(batch_size)
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(batch_size)
         else:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(batch_size)
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(batch_size)
 
         self.pretrain_ = False
         self.steps_ = tf.Variable(0.)
-        self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
 
         hist = super().fit(dataset, validation_data=validation_data, **fit_params)
 
@@ -1227,10 +1258,6 @@ def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
 
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         # Run forward pass.
         with tf.GradientTape() as tape:
             y_pred = self(Xs, training=True)
@@ -1418,6 +1445,22 @@ def score_estimator(self, X, y, sample_weight=None):
         if isinstance(score, (tuple, list)):
             score = score[0]
         return score
+
+
+    def _check_validation_data(self, validation_data, batch_size, shuffle):
+        if isinstance(validation_data, tuple):
+            X_val = validation_data[0]
+            y_val = validation_data[1]
+
+            validation_data = tf.data.Dataset.zip(
+                (tf.data.Dataset.from_tensor_slices(X_val),
+                 tf.data.Dataset.from_tensor_slices(y_val))
+            )
+            if shuffle:
+                validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
+            else:
+                validation_data = validation_data.batch(batch_size)
+        return validation_data
 
 
     def _get_legal_params(self, params):
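With this helper, validation data can be passed either as an already-built tf.data.Dataset or as a plain (X_val, y_val) tuple of arrays, which is zipped and batched here. A hedged usage sketch, assuming the DANN estimator with its default networks and hypothetical toy data (any estimator built on this base should behave the same way):

    import numpy as np
    from adapt.feature_based import DANN

    # Hypothetical toy data.
    Xs = np.random.randn(200, 10).astype("float32")
    ys = np.random.randn(200, 1).astype("float32")
    Xt = np.random.randn(150, 10).astype("float32")
    X_val, y_val = Xs[:40], ys[:40]

    model = DANN()  # default encoder / task / discriminator networks
    model.fit(Xs, ys, Xt=Xt,
              validation_data=(X_val, y_val),   # tuple converted by _check_validation_data
              validation_batch_size=32,
              epochs=2, batch_size=32, verbose=0)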
@@ -1467,13 +1510,17 @@ def _initialize_weights(self, shape_X):
 
 
     def _unpack_data(self, data):
-        data_X = data[0]
-        data_y = data[1]
-        Xs = data_X[:-1]
-        Xt = data_X[-1]
-        ys = data_y[:-1]
-        yt = data_y[-1]
-        return Xs, Xt, ys, ys
+        data_src = data[0]
+        data_tgt = data[1]
+        Xs = data_src[0]
+        ys = data_src[1]
+        if isinstance(data_tgt, tuple):
+            Xt = data_tgt[0]
+            yt = data_tgt[1]
+            return Xs, Xt, ys, yt
+        else:
+            Xt = data_tgt
+            return Xs, Xt, ys, None
 
 
     def _get_disc_metrics(self, ys_disc, yt_disc):
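With the rewritten _unpack_data, each batch is a (data_src, data_tgt) pair in which data_src is an (Xs, ys) tuple and data_tgt is either Xt alone (unlabeled target, yt returned as None) or an (Xt, yt) tuple. This is also why the train_step methods in the files below drop the old Xs = Xs[0] indexing. A minimal sketch of the batch layout such a zipped dataset produces, with hypothetical toy arrays:

    import numpy as np
    import tensorflow as tf

    # Hypothetical toy arrays.
    Xs = np.random.randn(8, 5).astype("float32")
    ys = np.random.randn(8, 1).astype("float32")
    Xt = np.random.randn(8, 5).astype("float32")

    dataset_src = tf.data.Dataset.from_tensor_slices((Xs, ys))
    dataset_tgt = tf.data.Dataset.from_tensor_slices(Xt)        # unlabeled target

    dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(4)
    batch = next(iter(dataset))

    # batch[0] is the (Xs, ys) pair, batch[1] is Xt; with a labeled target,
    # batch[1] would itself be an (Xt, yt) pair and yt would not be None.
    (xs_b, ys_b), xt_b = batch
    print(xs_b.shape, ys_b.shape, xt_b.shape)  # (4, 5) (4, 1) (4, 5)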

adapt/feature_based/_adda.py

Lines changed: 1 addition & 9 deletions
@@ -168,10 +168,6 @@ def pretrain_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
 
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         # loss
         with tf.GradientTape() as tape:
             # Forward pass
@@ -208,11 +204,7 @@ def train_step(self, data):
         else:
             # Unpack the data.
             Xs, Xt, ys, yt = self._unpack_data(data)
-
-            # Single source
-            Xs = Xs[0]
-            ys = ys[0]
-
+
             # loss
             with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape, tf.GradientTape() as disc_tape:
                 # Forward pass

adapt/feature_based/_cdan.py

Lines changed: 0 additions & 4 deletions
@@ -169,10 +169,6 @@ def __init__(self,
     def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
-
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
 
         # loss
         with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape, tf.GradientTape() as disc_tape:

adapt/feature_based/_dann.py

Lines changed: 0 additions & 4 deletions
@@ -138,10 +138,6 @@ def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
 
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         if self.lambda_ is None:
             _is_lambda_None = 1.
             lambda_ = 0.

adapt/feature_based/_deepcoral.py

Lines changed: 0 additions & 4 deletions
@@ -133,10 +133,6 @@ def __init__(self,
     def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
-
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
 
         if self.match_mean:
             _match_mean = 1.

adapt/feature_based/_mcd.py

Lines changed: 0 additions & 8 deletions
@@ -89,10 +89,6 @@ def __init__(self,
     def pretrain_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
-
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
 
         # loss
         with tf.GradientTape() as tape:
@@ -136,10 +132,6 @@ def train_step(self, data):
         else:
             # Unpack the data.
             Xs, Xt, ys, yt = self._unpack_data(data)
-
-            # Single source
-            Xs = Xs[0]
-            ys = ys[0]
 
             # loss
             with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape, tf.GradientTape() as disc_tape:

adapt/feature_based/_mdd.py

Lines changed: 0 additions & 4 deletions
@@ -88,10 +88,6 @@ def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
 
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         # If crossentropy take argmax of preds
         if hasattr(self.task_loss_, "name"):
             name = self.task_loss_.name

adapt/feature_based/_wdgrl.py

Lines changed: 0 additions & 4 deletions
@@ -131,10 +131,6 @@ def __init__(self,
     def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
-
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
 
         # loss
         with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape, tf.GradientTape() as disc_tape:
