@@ -897,41 +897,64 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
897
897
self ._is_fitted = True
898
898
self ._initialize_networks ()
899
899
self ._initialize_weights (X .shape [1 :])
900
+
901
+ # 2. Get Fit params
902
+ fit_params = self ._filter_params (super ().fit , fit_params )
900
903
901
- # 2. Prepare dataset
904
+ verbose = fit_params .get ("verbose" , 1 )
905
+ epochs = fit_params .get ("epochs" , 1 )
906
+ batch_size = fit_params .pop ("batch_size" , 32 )
907
+ shuffle = fit_params .pop ("shuffle" , True )
908
+ validation_data = fit_params .pop ("validation_data" , None )
909
+ validation_split = fit_params .pop ("validation_split" , 0. )
910
+ validation_batch_size = fit_params .pop ("validation_batch_size" , batch_size )
911
+
912
+ # 3. Prepare dataset
913
+
914
+ ### 3.1 Source
915
+ if not isinstance (X , tf .data .Dataset ):
916
+ check_arrays (X , y )
917
+ if len (y .shape ) <= 1 :
918
+ y = y .reshape (- 1 , 1 )
919
+
920
+ ### 3.2 Target
902
921
Xt , yt = self ._get_target_data (Xt , yt )
903
-
904
- check_arrays (X , y )
905
- if len (y .shape ) <= 1 :
906
- y = y .reshape (- 1 , 1 )
922
+ if not isinstance (Xt , tf .data .Dataset ):
923
+ if yt is None :
924
+ yt = y
925
+ check_array (Xt , ensure_2d = True , allow_nd = True )
926
+ else :
927
+ check_arrays (Xt , yt )
907
928
908
- if yt is None :
909
- yt = y
910
- check_array (Xt , ensure_2d = True , allow_nd = True )
911
- else :
912
- check_arrays (Xt , yt )
913
-
914
- if len (yt .shape ) <= 1 :
915
- yt = yt .reshape (- 1 , 1 )
929
+ if len (yt .shape ) <= 1 :
930
+ yt = yt .reshape (- 1 , 1 )
916
931
917
932
self ._save_validation_data (X , Xt )
918
933
934
+ ### 3.3 Domains
919
935
domains = fit_params .pop ("domains" , None )
920
-
936
+
921
937
if domains is None :
922
938
domains = np .zeros (len (X ))
923
-
939
+
924
940
domains = self ._check_domains (domains )
925
941
926
942
self .n_sources_ = int (np .max (domains )+ 1 )
927
-
943
+
928
944
sizes = np .array (
929
945
[np .sum (domains == dom ) for dom in range (self .n_sources_ )]+
930
946
[len (Xt )])
931
-
947
+
932
948
max_size = np .max (sizes )
933
949
repeats = np .ceil (max_size / sizes )
934
950
951
+ # TODO: validation_split is read from fit_params above, but the split below is still disabled (commented out)
952
+ # if validation_data is None and validation_split>0.:
953
+ # frac = int(len(dataset)*validation_split)
954
+ # validation_data = dataset.take(frac)
955
+ # dataset = dataset.skip(frac)
956
+
957
+
935
958
dataset_X = tf .data .Dataset .zip (tuple (
936
959
tf .data .Dataset .from_tensor_slices (X [domains == dom ]).repeat (repeats [dom ])
937
960
for dom in range (self .n_sources_ ))+
@@ -944,15 +967,6 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
944
967
(tf .data .Dataset .from_tensor_slices (yt ).repeat (repeats [- 1 ]),)
945
968
)
946
969
947
-
948
- # 3. Get Fit params
949
- fit_params = self ._filter_params (super ().fit , fit_params )
950
-
951
- verbose = fit_params .get ("verbose" , 1 )
952
- epochs = fit_params .get ("epochs" , 1 )
953
- batch_size = fit_params .pop ("batch_size" , 32 )
954
- shuffle = fit_params .pop ("shuffle" , True )
955
-
956
970
# 4. Pretraining
957
971
if not hasattr (self , "pretrain_" ):
958
972
if not hasattr (self , "pretrain" ):
@@ -993,7 +1007,21 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
993
1007
994
1008
self ._initialize_pretain_networks ()
995
1009
996
- # 5. Training
1010
+ # 5. Define validation set (an (X_val, y_val) tuple is wrapped into a batched tf.data.Dataset)
1011
+ if isinstance (validation_data , tuple ):
1012
+ X_val = validation_data [0 ]
1013
+ y_val = validation_data [1 ]
1014
+
1015
+ validation_data = tf .data .Dataset .zip (
1016
+ (tf .data .Dataset .from_tensor_slices (X_val ),
1017
+ tf .data .Dataset .from_tensor_slices (y_val ))
1018
+ )
1019
+ if shuffle :
1020
+ validation_data = validation_data .shuffle (buffer_size = 1024 ).batch (batch_size )
1021
+ else :
1022
+ validation_data = validation_data .batch (batch_size )
1023
+
1024
+ # 6. Training
997
1025
if (not self ._is_compiled ) or (self .pretrain_ ):
998
1026
self .compile ()
999
1027
@@ -1004,12 +1032,12 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
1004
1032
dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).shuffle (buffer_size = 1024 ).batch (batch_size )
1005
1033
else :
1006
1034
dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).batch (batch_size )
1007
-
1035
+
1008
1036
self .pretrain_ = False
1009
1037
self .steps_ = tf .Variable (0. )
1010
1038
self .total_steps_ = float (np .ceil (max_size / batch_size )* epochs )
1011
1039
1012
- hist = super ().fit (dataset , ** fit_params )
1040
+ hist = super ().fit (dataset , validation_data = validation_data , ** fit_params )
1013
1041
1014
1042
for k , v in hist .history .items ():
1015
1043
self .history_ [k ] = self .history_ .get (k , []) + v
0 commit comments