Skip to content

Commit 8185320

Browse files
Add validation data in BaseDeep
1 parent 31cc3fb commit 8185320

File tree

1 file changed

+57
-29
lines changed

1 file changed

+57
-29
lines changed

adapt/base.py

Lines changed: 57 additions & 29 deletions
Original file line number · Diff line number · Diff line change
@@ -897,41 +897,64 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
897897
self._is_fitted = True
898898
self._initialize_networks()
899899
self._initialize_weights(X.shape[1:])
900+
901+
# 2. Get Fit params
902+
fit_params = self._filter_params(super().fit, fit_params)
900903

901-
# 2. Prepare dataset
904+
verbose = fit_params.get("verbose", 1)
905+
epochs = fit_params.get("epochs", 1)
906+
batch_size = fit_params.pop("batch_size", 32)
907+
shuffle = fit_params.pop("shuffle", True)
908+
validation_data = fit_params.pop("validation_data", None)
909+
validation_split = fit_params.pop("validation_split", 0.)
910+
validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
911+
912+
# 3. Prepare dataset
913+
914+
### 3.1 Source
915+
if not isinstance(X, tf.data.Dataset):
916+
check_arrays(X, y)
917+
if len(y.shape) <= 1:
918+
y = y.reshape(-1, 1)
919+
920+
### 3.2 Target
902921
Xt, yt = self._get_target_data(Xt, yt)
903-
904-
check_arrays(X, y)
905-
if len(y.shape) <= 1:
906-
y = y.reshape(-1, 1)
922+
if not isinstance(Xt, tf.data.Dataset):
923+
if yt is None:
924+
yt = y
925+
check_array(Xt, ensure_2d=True, allow_nd=True)
926+
else:
927+
check_arrays(Xt, yt)
907928

908-
if yt is None:
909-
yt = y
910-
check_array(Xt, ensure_2d=True, allow_nd=True)
911-
else:
912-
check_arrays(Xt, yt)
913-
914-
if len(yt.shape) <= 1:
915-
yt = yt.reshape(-1, 1)
929+
if len(yt.shape) <= 1:
930+
yt = yt.reshape(-1, 1)
916931

917932
self._save_validation_data(X, Xt)
918933

934+
### 3.3 Domains
919935
domains = fit_params.pop("domains", None)
920-
936+
921937
if domains is None:
922938
domains = np.zeros(len(X))
923-
939+
924940
domains = self._check_domains(domains)
925941

926942
self.n_sources_ = int(np.max(domains)+1)
927-
943+
928944
sizes = np.array(
929945
[np.sum(domains==dom) for dom in range(self.n_sources_)]+
930946
[len(Xt)])
931-
947+
932948
max_size = np.max(sizes)
933949
repeats = np.ceil(max_size/sizes)
934950

951+
# Split if validation_split
952+
# if validation_data is None and validation_split>0.:
953+
# frac = int(len(dataset)*validation_split)
954+
# validation_data = dataset.take(frac)
955+
# dataset = dataset.skip(frac)
956+
957+
935958
dataset_X = tf.data.Dataset.zip(tuple(
936959
tf.data.Dataset.from_tensor_slices(X[domains==dom]).repeat(repeats[dom])
937960
for dom in range(self.n_sources_))+
@@ -944,15 +967,6 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
944967
(tf.data.Dataset.from_tensor_slices(yt).repeat(repeats[-1]),)
945968
)
946969

947-
948-
# 3. Get Fit params
949-
fit_params = self._filter_params(super().fit, fit_params)
950-
951-
verbose = fit_params.get("verbose", 1)
952-
epochs = fit_params.get("epochs", 1)
953-
batch_size = fit_params.pop("batch_size", 32)
954-
shuffle = fit_params.pop("shuffle", True)
955-
956970
# 4. Pretraining
957971
if not hasattr(self, "pretrain_"):
958972
if not hasattr(self, "pretrain"):
@@ -993,7 +1007,21 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
9931007

9941008
self._initialize_pretain_networks()
9951009

996-
# 5. Training
1010+
# 5. Define validation Set
1011+
if isinstance(validation_data, tuple):
1012+
X_val = validation_data[0]
1013+
y_val = validation_data[1]
1014+
1015+
validation_data = tf.data.Dataset.zip(
1016+
(tf.data.Dataset.from_tensor_slices(X_val),
1017+
tf.data.Dataset.from_tensor_slices(y_val))
1018+
)
1019+
if shuffle:
1020+
validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
1021+
else:
1022+
validation_data = validation_data.batch(batch_size)
1023+
1024+
# 6. Training
9971025
if (not self._is_compiled) or (self.pretrain_):
9981026
self.compile()
9991027

@@ -1004,12 +1032,12 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
10041032
dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(batch_size)
10051033
else:
10061034
dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(batch_size)
1007-
1035+
10081036
self.pretrain_ = False
10091037
self.steps_ = tf.Variable(0.)
10101038
self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
10111039

1012-
hist = super().fit(dataset, **fit_params)
1040+
hist = super().fit(dataset, validation_data=validation_data, **fit_params)
10131041

10141042
for k, v in hist.history.items():
10151043
self.history_[k] = self.history_.get(k, []) + v

0 commit comments

Comments (0)