@@ -367,7 +367,7 @@ def _save_validation_data(self, Xs, Xt):
367
367
else :
368
368
self .Xs_ = Xs
369
369
self .Xt_ = Xt
370
- self .src_index_ = np . arange ( len ( Xs ))
370
+ self .src_index_ = None
371
371
372
372
373
373
def _get_target_data (self , X , y ):
@@ -458,7 +458,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
458
458
if yt is not None :
459
459
Xt , yt = check_arrays (Xt , yt )
460
460
else :
461
- Xt = check_array (Xt )
461
+ Xt = check_array (Xt , ensure_2d = True , allow_nd = True )
462
462
set_random_seed (self .random_state )
463
463
464
464
self ._save_validation_data (X , Xt )
@@ -857,7 +857,7 @@ def __init__(self,
857
857
self ._self_setattr_tracking = True
858
858
859
859
860
- def fit (self , X , y , Xt = None , yt = None , domains = None , ** fit_params ):
860
+ def fit (self , X , y = None , Xt = None , yt = None , domains = None , ** fit_params ):
861
861
"""
862
862
Fit Model. Note that ``fit`` does not reset
863
863
the model but extends the training.
@@ -867,7 +867,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
867
867
X : array or Tensor
868
868
Source input data.
869
869
870
- y : array or Tensor
870
+ y : array or Tensor (default=None)
871
871
Source output data.
872
872
873
873
Xt : array (default=None)
@@ -889,71 +889,126 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
889
889
Returns
890
890
-------
891
891
self : returns an instance of self
892
- """
892
+ """
893
893
set_random_seed (self .random_state )
894
894
895
895
# 1. Initialize networks
896
896
if not hasattr (self , "_is_fitted" ):
897
897
self ._is_fitted = True
898
898
self ._initialize_networks ()
899
- self ._initialize_weights (X .shape [1 :])
899
+ if isinstance (X , tf .data .Dataset ):
900
+ first_elem = next (iter (X ))
901
+ if (not isinstance (first_elem , tuple ) or
902
+ not len (first_elem )== 2 ):
903
+ raise ValueError ("When first argument is a dataset. "
904
+ "It should return (x, y) tuples." )
905
+ else :
906
+ shape = first_elem [0 ].shape
907
+ else :
908
+ shape = X .shape [1 :]
909
+ self ._initialize_weights (shape )
910
+
911
+ # 2. Get Fit params
912
+ fit_params = self ._filter_params (super ().fit , fit_params )
900
913
901
- # 2. Prepare dataset
914
+ verbose = fit_params .get ("verbose" , 1 )
915
+ epochs = fit_params .get ("epochs" , 1 )
916
+ batch_size = fit_params .pop ("batch_size" , 32 )
917
+ shuffle = fit_params .pop ("shuffle" , True )
918
+ validation_data = fit_params .pop ("validation_data" , None )
919
+ validation_split = fit_params .pop ("validation_split" , 0. )
920
+ validation_batch_size = fit_params .pop ("validation_batch_size" , batch_size )
921
+
922
+ # 3. Prepare datasets
923
+
924
+ ### 3.1 Source
925
+ if not isinstance (X , tf .data .Dataset ):
926
+ check_arrays (X , y )
927
+ if len (y .shape ) <= 1 :
928
+ y = y .reshape (- 1 , 1 )
929
+
930
+ # Single source
931
+ if domains is None :
932
+ self .n_sources_ = 1
933
+
934
+ dataset_Xs = tf .data .Dataset .from_tensor_slices (X )
935
+ dataset_ys = tf .data .Dataset .from_tensor_slices (y )
936
+
937
+ # Multisource
938
+ else :
939
+ domains = self ._check_domains (domains )
940
+ self .n_sources_ = int (np .max (domains )+ 1 )
941
+
942
+ sizes = [np .sum (domains == dom )
943
+ for dom in range (self .n_sources_ )]
944
+
945
+ max_size = np .max (sizes )
946
+ repeats = np .ceil (max_size / sizes )
947
+
948
+ dataset_Xs = tf .data .Dataset .zip (tuple (
949
+ tf .data .Dataset .from_tensor_slices (X [domains == dom ]).repeat (repeats [dom ])
950
+ for dom in range (self .n_sources_ ))
951
+ )
952
+
953
+ dataset_ys = tf .data .Dataset .zip (tuple (
954
+ tf .data .Dataset .from_tensor_slices (y [domains == dom ]).repeat (repeats [dom ])
955
+ for dom in range (self .n_sources_ ))
956
+ )
957
+
958
+ dataset_src = tf .data .Dataset .zip ((dataset_Xs , dataset_ys ))
959
+
960
+ else :
961
+ dataset_src = X
962
+
963
+ ### 3.2 Target
902
964
Xt , yt = self ._get_target_data (Xt , yt )
965
+ if not isinstance (Xt , tf .data .Dataset ):
966
+ if yt is None :
967
+ check_array (Xt , ensure_2d = True , allow_nd = True )
968
+ dataset_tgt = tf .data .Dataset .from_tensor_slices (Xt )
903
969
904
- check_arrays (X , y )
905
- if len (y .shape ) <= 1 :
906
- y = y .reshape (- 1 , 1 )
970
+ else :
971
+ check_arrays (Xt , yt )
972
+
973
+ if len (yt .shape ) <= 1 :
974
+ yt = yt .reshape (- 1 , 1 )
975
+
976
+ dataset_Xt = tf .data .Dataset .from_tensor_slices (Xt )
977
+ dataset_yt = tf .data .Dataset .from_tensor_slices (yt )
978
+ dataset_tgt = tf .data .Dataset .zip ((dataset_Xt , dataset_yt ))
907
979
908
- if yt is None :
909
- yt = y
910
- check_array (Xt )
911
980
else :
912
- check_arrays (Xt , yt )
913
-
914
- if len (yt .shape ) <= 1 :
915
- yt = yt .reshape (- 1 , 1 )
981
+ dataset_tgt = Xt
916
982
917
983
self ._save_validation_data (X , Xt )
918
984
919
- domains = fit_params .pop ("domains" , None )
920
-
921
- if domains is None :
922
- domains = np .zeros (len (X ))
923
-
924
- domains = self ._check_domains (domains )
925
-
926
- self .n_sources_ = int (np .max (domains )+ 1 )
927
-
928
- sizes = np .array (
929
- [np .sum (domains == dom ) for dom in range (self .n_sources_ )]+
930
- [len (Xt )])
931
-
932
- max_size = np .max (sizes )
933
- repeats = np .ceil (max_size / sizes )
934
-
935
- dataset_X = tf .data .Dataset .zip (tuple (
936
- tf .data .Dataset .from_tensor_slices (X [domains == dom ]).repeat (repeats [dom ])
937
- for dom in range (self .n_sources_ ))+
938
- (tf .data .Dataset .from_tensor_slices (Xt ).repeat (repeats [- 1 ]),)
939
- )
940
-
941
- dataset_y = tf .data .Dataset .zip (tuple (
942
- tf .data .Dataset .from_tensor_slices (y [domains == dom ]).repeat (repeats [dom ])
943
- for dom in range (self .n_sources_ ))+
944
- (tf .data .Dataset .from_tensor_slices (yt ).repeat (repeats [- 1 ]),)
945
- )
946
-
947
-
948
- # 3. Get Fit params
949
- fit_params = self ._filter_params (super ().fit , fit_params )
950
-
951
- verbose = fit_params .get ("verbose" , 1 )
952
- epochs = fit_params .get ("epochs" , 1 )
953
- batch_size = fit_params .pop ("batch_size" , 32 )
954
- shuffle = fit_params .pop ("shuffle" , True )
985
+ # 4. Get validation data
986
+ # validation_data = self._check_validation_data(validation_data,
987
+ # validation_batch_size,
988
+ # shuffle)
989
+
990
+ if validation_data is None and validation_split > 0. :
991
+ if shuffle :
992
+ dataset_src = dataset_src .shuffle (buffer_size = 1024 )
993
+ frac = int (len (dataset_src )* validation_split )
994
+ validation_data = dataset_src .take (frac )
995
+ dataset_src = dataset_src .skip (frac )
996
+ validation_data = validation_data .batch (batch_size )
997
+
998
+ # 5. Set datasets
999
+ try :
1000
+ max_size = max (len (dataset_src ), len (dataset_tgt ))
1001
+ repeat_src = np .ceil (max_size / len (dataset_src ))
1002
+ repeat_tgt = np .ceil (max_size / len (dataset_tgt ))
1003
+
1004
+ dataset_src = dataset_src .repeat (repeat_src )
1005
+ dataset_tgt = dataset_tgt .repeat (repeat_tgt )
1006
+
1007
+ self .total_steps_ = float (np .ceil (max_size / batch_size )* epochs )
1008
+ except :
1009
+ pass
955
1010
956
- # 4 . Pretraining
1011
+ # 5 . Pretraining
957
1012
if not hasattr (self , "pretrain_" ):
958
1013
if not hasattr (self , "pretrain" ):
959
1014
self .pretrain_ = False
@@ -980,36 +1035,39 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
980
1035
pre_epochs = prefit_params .pop ("epochs" , epochs )
981
1036
pre_batch_size = prefit_params .pop ("batch_size" , batch_size )
982
1037
pre_shuffle = prefit_params .pop ("shuffle" , shuffle )
1038
+ prefit_params .pop ("validation_data" , None )
1039
+ prefit_params .pop ("validation_split" , None )
1040
+ prefit_params .pop ("validation_batch_size" , None )
983
1041
984
1042
if pre_shuffle :
985
- dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).shuffle (buffer_size = 1024 ).batch (pre_batch_size )
1043
+ dataset = tf .data .Dataset .zip ((dataset_src , dataset_tgt )).shuffle (buffer_size = 1024 ).batch (pre_batch_size )
986
1044
else :
987
- dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).batch (pre_batch_size )
1045
+ dataset = tf .data .Dataset .zip ((dataset_src , dataset_tgt )).batch (pre_batch_size )
988
1046
989
- hist = super ().fit (dataset , epochs = pre_epochs , verbose = pre_verbose , ** prefit_params )
1047
+ hist = super ().fit (dataset , validation_data = validation_data ,
1048
+ epochs = pre_epochs , verbose = pre_verbose , ** prefit_params )
990
1049
991
1050
for k , v in hist .history .items ():
992
1051
self .pretrain_history_ [k ] = self .pretrain_history_ .get (k , []) + v
993
1052
994
1053
self ._initialize_pretain_networks ()
995
-
996
- # 5. Training
1054
+
1055
+ # 6. Compile
997
1056
if (not self ._is_compiled ) or (self .pretrain_ ):
998
1057
self .compile ()
999
1058
1000
1059
if not hasattr (self , "history_" ):
1001
1060
self .history_ = {}
1002
1061
1062
+ # 7. Training
1003
1063
if shuffle :
1004
- dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).shuffle (buffer_size = 1024 ).batch (batch_size )
1064
+ dataset = tf .data .Dataset .zip ((dataset_src , dataset_tgt )).shuffle (buffer_size = 1024 ).batch (batch_size )
1005
1065
else :
1006
- dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).batch (batch_size )
1007
-
1066
+ dataset = tf .data .Dataset .zip ((dataset_src , dataset_tgt )).batch (batch_size )
1067
+
1008
1068
self .pretrain_ = False
1009
- self .steps_ = tf .Variable (0. )
1010
- self .total_steps_ = float (np .ceil (max_size / batch_size )* epochs )
1011
1069
1012
- hist = super ().fit (dataset , ** fit_params )
1070
+ hist = super ().fit (dataset , validation_data = validation_data , ** fit_params )
1013
1071
1014
1072
for k , v in hist .history .items ():
1015
1073
self .history_ [k ] = self .history_ .get (k , []) + v
@@ -1188,6 +1246,12 @@ def compile(self,
1188
1246
super ().compile (
1189
1247
** compile_params
1190
1248
)
1249
+
1250
+ # Set optimizer for encoder and discriminator
1251
+ if not hasattr (self , "optimizer_enc" ):
1252
+ self .optimizer_enc = self .optimizer
1253
+ if not hasattr (self , "optimizer_disc" ):
1254
+ self .optimizer_disc = self .optimizer
1191
1255
1192
1256
1193
1257
def call (self , inputs ):
@@ -1199,10 +1263,6 @@ def train_step(self, data):
1199
1263
# Unpack the data.
1200
1264
Xs , Xt , ys , yt = self ._unpack_data (data )
1201
1265
1202
- # Single source
1203
- Xs = Xs [0 ]
1204
- ys = ys [0 ]
1205
-
1206
1266
# Run forward pass.
1207
1267
with tf .GradientTape () as tape :
1208
1268
y_pred = self (Xs , training = True )
@@ -1376,7 +1436,7 @@ def score_estimator(self, X, y, sample_weight=None):
1376
1436
score : float
1377
1437
Score.
1378
1438
"""
1379
- if np .prod (X .shape ) <= 10 ** 8 :
1439
+ if hasattr ( X , "shape" ) and np .prod (X .shape ) <= 10 ** 8 :
1380
1440
score = self .evaluate (
1381
1441
X , y ,
1382
1442
sample_weight = sample_weight ,
@@ -1390,6 +1450,22 @@ def score_estimator(self, X, y, sample_weight=None):
1390
1450
if isinstance (score , (tuple , list )):
1391
1451
score = score [0 ]
1392
1452
return score
1453
+
1454
+
1455
+ # def _check_validation_data(self, validation_data, batch_size, shuffle):
1456
+ # if isinstance(validation_data, tuple):
1457
+ # X_val = validation_data[0]
1458
+ # y_val = validation_data[1]
1459
+
1460
+ # validation_data = tf.data.Dataset.zip(
1461
+ # (tf.data.Dataset.from_tensor_slices(X_val),
1462
+ # tf.data.Dataset.from_tensor_slices(y_val))
1463
+ # )
1464
+ # if shuffle:
1465
+ # validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
1466
+ # else:
1467
+ # validation_data = validation_data.batch(batch_size)
1468
+ # return validation_data
1393
1469
1394
1470
1395
1471
def _get_legal_params (self , params ):
@@ -1405,7 +1481,7 @@ def _get_legal_params(self, params):
1405
1481
if (optimizer is not None ) and (not isinstance (optimizer , str )):
1406
1482
legal_params_fct .append (optimizer .__init__ )
1407
1483
1408
- legal_params = ["domain" , "val_sample_size" ]
1484
+ legal_params = ["domain" , "val_sample_size" , "optimizer_enc" , "optimizer_disc" ]
1409
1485
for func in legal_params_fct :
1410
1486
args = [
1411
1487
p .name
@@ -1439,13 +1515,17 @@ def _initialize_weights(self, shape_X):
1439
1515
1440
1516
1441
1517
def _unpack_data (self , data ):
1442
- data_X = data [0 ]
1443
- data_y = data [1 ]
1444
- Xs = data_X [:- 1 ]
1445
- Xt = data_X [- 1 ]
1446
- ys = data_y [:- 1 ]
1447
- yt = data_y [- 1 ]
1448
- return Xs , Xt , ys , ys
1518
+ data_src = data [0 ]
1519
+ data_tgt = data [1 ]
1520
+ Xs = data_src [0 ]
1521
+ ys = data_src [1 ]
1522
+ if isinstance (data_tgt , tuple ):
1523
+ Xt = data_tgt [0 ]
1524
+ yt = data_tgt [1 ]
1525
+ return Xs , Xt , ys , yt
1526
+ else :
1527
+ Xt = data_tgt
1528
+ return Xs , Xt , ys , None
1449
1529
1450
1530
1451
1531
def _get_disc_metrics (self , ys_disc , yt_disc ):
0 commit comments