@@ -367,7 +367,7 @@ def _save_validation_data(self, Xs, Xt):
        else:
            self.Xs_ = Xs
            self.Xt_ = Xt
-            self.src_index_ = np.arange(len(Xs))
+            self.src_index_ = None


    def _get_target_data(self, X, y):
@@ -857,7 +857,7 @@ def __init__(self,
        self._self_setattr_tracking = True


-    def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
+    def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
        """
        Fit Model. Note that ``fit`` does not reset
        the model but extends the training.
@@ -867,7 +867,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
        X : array or Tensor
            Source input data.

-        y : array or Tensor
+        y : array or Tensor (default=None)
            Source output data.

        Xt : array (default=None)
@@ -896,64 +896,120 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
        if not hasattr(self, "_is_fitted"):
            self._is_fitted = True
            self._initialize_networks()
-            self._initialize_weights(X.shape[1:])
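+            # Infer the input shape from the first element when X is a
+            # tf.data.Dataset instead of reading it from an array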
+            if isinstance(X, tf.data.Dataset):
+                first_elem = next(iter(X))
+                if (not isinstance(first_elem, tuple) or
+                        not len(first_elem) == 2):
+                    raise ValueError("When the first argument is a dataset, "
+                                     "it should return (x, y) tuples.")
+                else:
+                    shape = first_elem[0].shape
+            else:
+                shape = X.shape[1:]
+            self._initialize_weights(shape)
+
+        # 2. Get Fit params
+        fit_params = self._filter_params(super().fit, fit_params)
+
+        verbose = fit_params.get("verbose", 1)
+        epochs = fit_params.get("epochs", 1)
+        batch_size = fit_params.pop("batch_size", 32)
+        shuffle = fit_params.pop("shuffle", True)
+        validation_data = fit_params.pop("validation_data", None)
+        validation_split = fit_params.pop("validation_split", 0.)
+        validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
+
+        # 3. Prepare datasets

-        # 2. Prepare dataset
+        ### 3.1 Source
+        if not isinstance(X, tf.data.Dataset):
+            check_arrays(X, y)
+            if len(y.shape) <= 1:
+                y = y.reshape(-1, 1)
+
+            # Single source
+            if domains is None:
+                self.n_sources_ = 1
+
+                dataset_Xs = tf.data.Dataset.from_tensor_slices(X)
+                dataset_ys = tf.data.Dataset.from_tensor_slices(y)
+
+            # Multisource
+            else:
+                domains = self._check_domains(domains)
+                self.n_sources_ = int(np.max(domains) + 1)
+
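+                # Oversample each domain so all reach the size of the largest
+                # one: Dataset.zip truncates to its shortest component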
+                sizes = [np.sum(domains == dom)
+                         for dom in range(self.n_sources_)]
+
+                max_size = np.max(sizes)
+                repeats = np.ceil(max_size / sizes)
+
+                dataset_Xs = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(X[domains == dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+                dataset_ys = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(y[domains == dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
+
+        else:
+            dataset_src = X
+
+        ### 3.2 Target

        Xt, yt = self._get_target_data(Xt, yt)
+        if not isinstance(Xt, tf.data.Dataset):
+            if yt is None:
+                check_array(Xt, ensure_2d=True, allow_nd=True)
+                dataset_tgt = tf.data.Dataset.from_tensor_slices(Xt)

-        check_arrays(X, y)
-        if len(y.shape) <= 1:
-            y = y.reshape(-1, 1)
+            else:
+                check_arrays(Xt, yt)
+
+                if len(yt.shape) <= 1:
+                    yt = yt.reshape(-1, 1)
+
+                dataset_Xt = tf.data.Dataset.from_tensor_slices(Xt)
+                dataset_yt = tf.data.Dataset.from_tensor_slices(yt)
+                dataset_tgt = tf.data.Dataset.zip((dataset_Xt, dataset_yt))

-        if yt is None:
-            yt = y
-            check_array(Xt, ensure_2d=True, allow_nd=True)
        else:
-            check_arrays(Xt, yt)
-
-            if len(yt.shape) <= 1:
-                yt = yt.reshape(-1, 1)
+            dataset_tgt = Xt


        self._save_validation_data(X, Xt)

-        domains = fit_params.pop("domains", None)
-
-        if domains is None:
-            domains = np.zeros(len(X))
-
-        domains = self._check_domains(domains)
-
-        self.n_sources_ = int(np.max(domains) + 1)
-
-        sizes = np.array(
-            [np.sum(domains == dom) for dom in range(self.n_sources_)] +
-            [len(Xt)])
-
-        max_size = np.max(sizes)
-        repeats = np.ceil(max_size / sizes)
-
-        dataset_X = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(X[domains == dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_)) +
-            (tf.data.Dataset.from_tensor_slices(Xt).repeat(repeats[-1]),)
-        )
-
-        dataset_y = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(y[domains == dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_)) +
-            (tf.data.Dataset.from_tensor_slices(yt).repeat(repeats[-1]),)
-        )
-
-
-        # 3. Get Fit params
-        fit_params = self._filter_params(super().fit, fit_params)
-
-        verbose = fit_params.get("verbose", 1)
-        epochs = fit_params.get("epochs", 1)
-        batch_size = fit_params.pop("batch_size", 32)
-        shuffle = fit_params.pop("shuffle", True)
+        # 4. Get validation data
+        validation_data = self._check_validation_data(validation_data,
+                                                      validation_batch_size,
+                                                      shuffle)
+
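+        # No explicit validation set: carve one out of the source dataset
+        # (after an optional shuffle so the split is random)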
+        if validation_data is None and validation_split > 0.:
+            if shuffle:
+                dataset_src = dataset_src.shuffle(buffer_size=1024)
+            frac = int(len(dataset_src) * validation_split)
+            validation_data = dataset_src.take(frac)
+            dataset_src = dataset_src.skip(frac)
+            validation_data = validation_data.batch(batch_size)
+
+        # 5. Set datasets
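+        # Repeat the smaller dataset so source and target line up; len() is
+        # undefined (TypeError) for infinite or unknown-cardinality datasets,
+        # in which case balancing and total_steps_ are skipped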
+        try:
+            max_size = max(len(dataset_src), len(dataset_tgt))
+            repeat_src = np.ceil(max_size / len(dataset_src))
+            repeat_tgt = np.ceil(max_size / len(dataset_tgt))
+
+            dataset_src = dataset_src.repeat(repeat_src)
+            dataset_tgt = dataset_tgt.repeat(repeat_tgt)
+
+            self.total_steps_ = float(np.ceil(max_size / batch_size) * epochs)
+        except TypeError:
+            pass

-        # 4. Pretraining
+        # 6. Pretraining
        if not hasattr(self, "pretrain_"):
            if not hasattr(self, "pretrain"):
                self.pretrain_ = False
@@ -980,36 +1036,39 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
            pre_epochs = prefit_params.pop("epochs", epochs)
            pre_batch_size = prefit_params.pop("batch_size", batch_size)
            pre_shuffle = prefit_params.pop("shuffle", shuffle)
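+            # Pop the validation kwargs: validation_data is forwarded
+            # explicitly to super().fit below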
+            prefit_params.pop("validation_data", None)
+            prefit_params.pop("validation_split", None)
+            prefit_params.pop("validation_batch_size", None)


            if pre_shuffle:
-                dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(pre_batch_size)
+                dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(pre_batch_size)
            else:
-                dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(pre_batch_size)
+                dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(pre_batch_size)

-            hist = super().fit(dataset, epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
+            hist = super().fit(dataset, validation_data=validation_data,
+                               epochs=pre_epochs, verbose=pre_verbose, **prefit_params)

            for k, v in hist.history.items():
                self.pretrain_history_[k] = self.pretrain_history_.get(k, []) + v

            self._initialize_pretain_networks()
-
-        # 5. Training
+
+        # 7. Training
        if (not self._is_compiled) or (self.pretrain_):
            self.compile()

        if not hasattr(self, "history_"):
            self.history_ = {}

        if shuffle:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(batch_size)
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(batch_size)
        else:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(batch_size)
-
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(batch_size)
+

        self.pretrain_ = False
        self.steps_ = tf.Variable(0.)
-        self.total_steps_ = float(np.ceil(max_size / batch_size) * epochs)

-        hist = super().fit(dataset, **fit_params)
+        hist = super().fit(dataset, validation_data=validation_data, **fit_params)

        for k, v in hist.history.items():
            self.history_[k] = self.history_.get(k, []) + v
@@ -1199,10 +1258,6 @@ def train_step(self, data):
        # Unpack the data.
        Xs, Xt, ys, yt = self._unpack_data(data)

-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
        # Run forward pass.
        with tf.GradientTape() as tape:
            y_pred = self(Xs, training=True)
@@ -1390,6 +1445,22 @@ def score_estimator(self, X, y, sample_weight=None):
        if isinstance(score, (tuple, list)):
            score = score[0]
        return score
+
+
+    def _check_validation_data(self, validation_data, batch_size, shuffle):
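+        # An (X_val, y_val) tuple is converted into a batched tf.data.Dataset;
+        # any other validation_data value is returned unchanged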
+        if isinstance(validation_data, tuple):
+            X_val = validation_data[0]
+            y_val = validation_data[1]
+
+            validation_data = tf.data.Dataset.zip(
+                (tf.data.Dataset.from_tensor_slices(X_val),
+                 tf.data.Dataset.from_tensor_slices(y_val))
+            )
+            if shuffle:
+                validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
+            else:
+                validation_data = validation_data.batch(batch_size)
+        return validation_data


    def _get_legal_params(self, params):
@@ -1439,13 +1510,17 @@ def _initialize_weights(self, shape_X):


    def _unpack_data(self, data):
-        data_X = data[0]
-        data_y = data[1]
-        Xs = data_X[:-1]
-        Xt = data_X[-1]
-        ys = data_y[:-1]
-        yt = data_y[-1]
-        return Xs, Xt, ys, ys
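+        # Each batch is a (source, target) pair: source is always an
+        # (Xs, ys) tuple, target is (Xt, yt) or Xt alone when unlabeled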
+        data_src = data[0]
+        data_tgt = data[1]
+        Xs = data_src[0]
+        ys = data_src[1]
+        if isinstance(data_tgt, tuple):
+            Xt = data_tgt[0]
+            yt = data_tgt[1]
+            return Xs, Xt, ys, yt
+        else:
+            Xt = data_tgt
+            return Xs, Xt, ys, None


    def _get_disc_metrics(self, ys_disc, yt_disc):