@@ -367,7 +367,7 @@ def _save_validation_data(self, Xs, Xt):
         else:
             self.Xs_ = Xs
             self.Xt_ = Xt
-            self.src_index_ = np.arange(len(Xs))
+            self.src_index_ = None
 
 
     def _get_target_data(self, X, y):
@@ -857,7 +857,7 @@ def __init__(self,
         self._self_setattr_tracking = True
 
 
-    def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
+    def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
         """
         Fit Model. Note that ``fit`` does not reset
         the model but extend the training.
@@ -867,7 +867,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         X : array or Tensor
             Source input data.
 
-        y : array or Tensor
+        y : array or Tensor (default=None)
            Source output data.
 
         Xt : array (default=None)
@@ -896,7 +896,18 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         if not hasattr(self, "_is_fitted"):
             self._is_fitted = True
             self._initialize_networks()
-            self._initialize_weights(X.shape[1:])
+            if isinstance(X, tf.data.Dataset):
+                first_elem = next(iter(X))
+                if (not isinstance(first_elem, tuple) or
+                    not len(first_elem) == 2):
+                    raise ValueError("When first argument is a dataset. "
+                                     "It should return (x, y) tuples.")
+                else:
+                    shape = first_elem[0].shape
+            else:
+                shape = X.shape[1:]
+            print(shape)
+            self._initialize_weights(shape)
 
         # 2. Get Fit params
         fit_params = self._filter_params(super().fit, fit_params)
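The new branch lets `fit` accept a `tf.data.Dataset` as `X`, probing its first element for the input shape (which is also why `y` becomes optional in the new signature). A minimal standalone sketch of that shape probe; `infer_input_shape` is an illustrative name, not part of the library:

```python
import numpy as np
import tensorflow as tf

def infer_input_shape(X):
    # A dataset must yield (x, y) pairs: peek at one element and read
    # the feature shape; plain arrays just drop the batch axis.
    if isinstance(X, tf.data.Dataset):
        first_elem = next(iter(X))
        if not (isinstance(first_elem, tuple) and len(first_elem) == 2):
            raise ValueError("When the first argument is a dataset, "
                             "it should return (x, y) tuples.")
        return first_elem[0].shape
    return X.shape[1:]

ds = tf.data.Dataset.from_tensor_slices(
    (np.ones((8, 3), "float32"), np.ones((8, 1), "float32")))
print(infer_input_shape(ds))               # (3,)
print(infer_input_shape(np.ones((8, 3))))  # (3,)
```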
@@ -909,65 +920,96 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         validation_split = fit_params.pop("validation_split", 0.)
         validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
 
-        # 3. Prepare dataset
+        # 3. Prepare datasets
 
         ### 3.1 Source
         if not isinstance(X, tf.data.Dataset):
             check_arrays(X, y)
             if len(y.shape) <= 1:
                 y = y.reshape(-1, 1)
 
+            # Single source
+            if domains is None:
+                self.n_sources_ = 1
+
+                dataset_Xs = tf.data.Dataset.from_tensor_slices(X)
+                dataset_ys = tf.data.Dataset.from_tensor_slices(y)
+
+            # Multisource
+            else:
+                domains = self._check_domains(domains)
+                self.n_sources_ = int(np.max(domains)+1)
+
+                sizes = [np.sum(domains == dom)
+                         for dom in range(self.n_sources_)]
+
+                max_size = np.max(sizes)
+                repeats = np.ceil(max_size / sizes)
+
+                dataset_Xs = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(X[domains == dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+                dataset_ys = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(y[domains == dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
+
+        else:
+            dataset_src = X
+
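In the multi-source branch, each domain's dataset is repeated until all domains match the largest one, then zipped so every step draws one sample per domain. A self-contained sketch of that balancing pattern with two toy domains (variable names mirror the diff, the data is made up):

```python
import numpy as np
import tensorflow as tf

X = np.random.randn(10, 2).astype("float32")
domains = np.array([0] * 7 + [1] * 3)  # domain 1 is under-represented

n_sources = int(np.max(domains) + 1)
sizes = [int(np.sum(domains == dom)) for dom in range(n_sources)]
max_size = max(sizes)
# Repeat smaller domains so one pass sees every domain a similar
# number of times; zip stops at the shortest component.
repeats = [int(np.ceil(max_size / size)) for size in sizes]

dataset_Xs = tf.data.Dataset.zip(tuple(
    tf.data.Dataset.from_tensor_slices(X[domains == dom]).repeat(repeats[dom])
    for dom in range(n_sources)))

for xs in dataset_Xs.take(2):
    print([tuple(x.shape) for x in xs])  # one sample per domain per step
```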
         ### 3.2 Target
         Xt, yt = self._get_target_data(Xt, yt)
         if not isinstance(Xt, tf.data.Dataset):
             if yt is None:
-                yt = y
                 check_array(Xt, ensure_2d=True, allow_nd=True)
+                dataset_tgt = tf.data.Dataset.from_tensor_slices(Xt)
+
             else:
                 check_arrays(Xt, yt)
 
-            if len(yt.shape) <= 1:
-                yt = yt.reshape(-1, 1)
+                if len(yt.shape) <= 1:
+                    yt = yt.reshape(-1, 1)
+
+                dataset_Xt = tf.data.Dataset.from_tensor_slices(Xt)
+                dataset_yt = tf.data.Dataset.from_tensor_slices(yt)
+                dataset_tgt = tf.data.Dataset.zip((dataset_Xt, dataset_yt))
+
+        else:
+            dataset_tgt = Xt
 
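The target pipeline mirrors the source one but tolerates missing labels: an unlabeled target becomes a dataset of inputs only, a labeled one a zipped `(Xt, yt)` dataset. A hedged sketch of both cases; `make_target_dataset` is a hypothetical helper, not the library's API:

```python
import numpy as np
import tensorflow as tf

def make_target_dataset(Xt, yt=None):
    if yt is None:
        # Unlabeled target: inputs only.
        return tf.data.Dataset.from_tensor_slices(Xt)
    # Labeled target: force 2D labels, then pair them with the inputs.
    if len(yt.shape) <= 1:
        yt = yt.reshape(-1, 1)
    return tf.data.Dataset.zip(
        (tf.data.Dataset.from_tensor_slices(Xt),
         tf.data.Dataset.from_tensor_slices(yt)))

Xt = np.zeros((5, 2), dtype="float32")
print(make_target_dataset(Xt).element_spec)               # single TensorSpec
print(make_target_dataset(Xt, np.zeros(5)).element_spec)  # (x, y) pair
```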
         self._save_validation_data(X, Xt)
 
-        ### 3.3 Domains
-        domains = fit_params.pop("domains", None)
-
-        if domains is None:
-            domains = np.zeros(len(X))
-
-        domains = self._check_domains(domains)
-
-        self.n_sources_ = int(np.max(domains)+1)
-
-        sizes = np.array(
-            [np.sum(domains == dom) for dom in range(self.n_sources_)]+
-            [len(Xt)])
-
-        max_size = np.max(sizes)
-        repeats = np.ceil(max_size / sizes)
-
-        # Split if validation_split
-        # if validation_data is None and validation_split>0.:
-        #     frac = int(len(dataset)*validation_split)
-        #     validation_data = dataset.take(frac)
-        #     dataset = dataset.skip(frac)
-
+        # 4. Get validation data
+        validation_data = self._check_validation_data(validation_data,
+                                                      validation_batch_size,
+                                                      shuffle)
 
-        dataset_X = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(X[domains == dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_))+
-            (tf.data.Dataset.from_tensor_slices(Xt).repeat(repeats[-1]),)
-        )
-
-        dataset_y = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(y[domains == dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_))+
-            (tf.data.Dataset.from_tensor_slices(yt).repeat(repeats[-1]),)
-        )
+        if validation_data is None and validation_split > 0.:
+            if shuffle:
+                dataset_src = dataset_src.shuffle(buffer_size=1024)
+            frac = int(len(dataset_src)*validation_split)
+            validation_data = dataset_src.take(frac)
+            dataset_src = dataset_src.skip(frac)
+            validation_data = validation_data.batch(batch_size)
+
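The previously commented-out validation split is now active at the dataset level: shuffle, carve off the first fraction with `take`, keep the rest with `skip`. A minimal sketch with toy sizes. One caveat worth noting: with TensorFlow's default `reshuffle_each_iteration=True`, re-iterating the split datasets reshuffles the parent, so the two subsets are only guaranteed disjoint within a single pass.

```python
import tensorflow as tf

dataset_src = tf.data.Dataset.range(100)
validation_split = 0.2
batch_size = 16

dataset_src = dataset_src.shuffle(buffer_size=1024)
frac = int(len(dataset_src) * validation_split)         # 20 examples
validation_data = dataset_src.take(frac).batch(batch_size)
dataset_src = dataset_src.skip(frac)                    # remaining 80
```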
+        # 5. Set datasets
+        try:
+            max_size = max(len(dataset_src), len(dataset_tgt))
+            repeat_src = np.ceil(max_size/len(dataset_src))
+            repeat_tgt = np.ceil(max_size/len(dataset_tgt))
+
+            dataset_src = dataset_src.repeat(repeat_src)
+            dataset_tgt = dataset_tgt.repeat(repeat_tgt)
+
+            self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
+        except:
+            pass
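`len()` on a `tf.data.Dataset` raises `TypeError` when the cardinality is infinite or unknown, which is what the bare `try/except` silently absorbs. A sketch of the same length alignment using `Dataset.cardinality()` to make the guard explicit; this is a suggested alternative, not the committed code:

```python
import numpy as np
import tensorflow as tf

def aligned_repeats(dataset_src, dataset_tgt, batch_size, epochs):
    # cardinality() returns -1 (INFINITE) or -2 (UNKNOWN) in the
    # cases where len() would raise.
    card_src = int(dataset_src.cardinality())
    card_tgt = int(dataset_tgt.cardinality())
    if card_src <= 0 or card_tgt <= 0:
        return dataset_src, dataset_tgt, None  # lengths unknown: leave as-is
    max_size = max(card_src, card_tgt)
    dataset_src = dataset_src.repeat(int(np.ceil(max_size / card_src)))
    dataset_tgt = dataset_tgt.repeat(int(np.ceil(max_size / card_tgt)))
    total_steps = float(np.ceil(max_size / batch_size) * epochs)
    return dataset_src, dataset_tgt, total_steps
```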
 
-        # 4. Pretraining
+        # 5. Pretraining
         if not hasattr(self, "pretrain_"):
             if not hasattr(self, "pretrain"):
                 self.pretrain_ = False
@@ -994,32 +1036,22 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         pre_epochs = prefit_params.pop("epochs", epochs)
         pre_batch_size = prefit_params.pop("batch_size", batch_size)
         pre_shuffle = prefit_params.pop("shuffle", shuffle)
+        prefit_params.pop("validation_data", None)
+        prefit_params.pop("validation_split", None)
+        prefit_params.pop("validation_batch_size", None)
 
         if pre_shuffle:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(pre_batch_size)
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(pre_batch_size)
         else:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(pre_batch_size)
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(pre_batch_size)
 
-        hist = super().fit(dataset, epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
+        hist = super().fit(dataset, validation_data=validation_data,
+                           epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
 
         for k, v in hist.history.items():
             self.pretrain_history_[k] = self.pretrain_history_.get(k, []) + v
 
         self._initialize_pretain_networks()
-
-        # 5. Define validation Set
-        if isinstance(validation_data, tuple):
-            X_val = validation_data[0]
-            y_val = validation_data[1]
-
-            validation_data = tf.data.Dataset.zip(
-                (tf.data.Dataset.from_tensor_slices(X_val),
-                 tf.data.Dataset.from_tensor_slices(y_val))
-            )
-            if shuffle:
-                validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
-            else:
-                validation_data = validation_data.batch(batch_size)
 
         # 6. Training
         if (not self._is_compiled) or (self.pretrain_):
@@ -1029,13 +1061,12 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
             self.history_ = {}
 
         if shuffle:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(batch_size)
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(batch_size)
         else:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(batch_size)
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(batch_size)
 
         self.pretrain_ = False
         self.steps_ = tf.Variable(0.)
-        self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
 
         hist = super().fit(dataset, validation_data=validation_data, **fit_params)
 
@@ -1227,10 +1258,6 @@ def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
 
-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         # Run forward pass.
         with tf.GradientTape() as tape:
             y_pred = self(Xs, training=True)
@@ -1418,6 +1445,22 @@ def score_estimator(self, X, y, sample_weight=None):
         if isinstance(score, (tuple, list)):
             score = score[0]
         return score
+
+
+    def _check_validation_data(self, validation_data, batch_size, shuffle):
+        if isinstance(validation_data, tuple):
+            X_val = validation_data[0]
+            y_val = validation_data[1]
+
+            validation_data = tf.data.Dataset.zip(
+                (tf.data.Dataset.from_tensor_slices(X_val),
+                 tf.data.Dataset.from_tensor_slices(y_val))
+            )
+            if shuffle:
+                validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
+            else:
+                validation_data = validation_data.batch(batch_size)
+        return validation_data
 
 
     def _get_legal_params(self, params):
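The new `_check_validation_data` helper centralizes the tuple-to-dataset conversion that the removed "5. Define validation Set" block used to do inline in `fit`, so both the pretraining and training phases can reuse it. A usage sketch of the same conversion with toy arrays (shapes are illustrative):

```python
import numpy as np
import tensorflow as tf

X_val = np.zeros((32, 3), dtype="float32")
y_val = np.zeros((32, 1), dtype="float32")

# Equivalent to passing validation_data=(X_val, y_val) through the helper:
validation_data = tf.data.Dataset.zip(
    (tf.data.Dataset.from_tensor_slices(X_val),
     tf.data.Dataset.from_tensor_slices(y_val))
).shuffle(buffer_size=1024).batch(16)

for xb, yb in validation_data.take(1):
    print(xb.shape, yb.shape)  # (16, 3) (16, 1)
```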
@@ -1467,13 +1510,17 @@ def _initialize_weights(self, shape_X):
 
 
     def _unpack_data(self, data):
-        data_X = data[0]
-        data_y = data[1]
-        Xs = data_X[:-1]
-        Xt = data_X[-1]
-        ys = data_y[:-1]
-        yt = data_y[-1]
-        return Xs, Xt, ys, ys
+        data_src = data[0]
+        data_tgt = data[1]
+        Xs = data_src[0]
+        ys = data_src[1]
+        if isinstance(data_tgt, tuple):
+            Xt = data_tgt[0]
+            yt = data_tgt[1]
+            return Xs, Xt, ys, yt
+        else:
+            Xt = data_tgt
+            return Xs, Xt, ys, None
 
 
     def _get_disc_metrics(self, ys_disc, yt_disc):
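With the `(source, target)` batch layout introduced above, `_unpack_data` now receives `((Xs, ys), target)` and distinguishes a labeled target tuple from a bare `Xt`; it also fixes the old version, which returned `ys` twice. A plain-Python sketch of the dispatch (a standalone rendering, not the method itself):

```python
def unpack_data(data):
    # Batches arrive as ((Xs, ys), target), where target is either
    # Xt alone or an (Xt, yt) tuple.
    (Xs, ys), data_tgt = data
    if isinstance(data_tgt, tuple):
        Xt, yt = data_tgt
        return Xs, Xt, ys, yt
    return Xs, data_tgt, ys, None

print(unpack_data((("Xs", "ys"), "Xt")))           # yt is None
print(unpack_data((("Xs", "ys"), ("Xt", "yt"))))   # labeled target
```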