@@ -241,7 +241,8 @@ def train_model_diversity(
241
241
model ,
242
242
train_loader ,
243
243
valid_loader ,
244
- discrete_diversity_matrix ,
244
+ discrete_diversity_matrix_train ,
245
+ discrete_diversity_matrix_valid ,
245
246
optimizer ,
246
247
criterion ,
247
248
num_epochs ,
@@ -270,7 +271,7 @@ def train_model_diversity(
270
271
anchor = model (features , edge_index )
271
272
272
273
positive_index , negative_index = get_positive_and_negative (
273
- discrete_diversity_matrix , index , train_loader . dataset
274
+ discrete_diversity_matrix_train , i
274
275
)
275
276
if positive_index is None or negative_index is None :
276
277
continue
@@ -317,7 +318,7 @@ def train_model_diversity(
317
318
anchor = model (features , edge_index )
318
319
319
320
positive_index , negative_index = get_positive_and_negative (
320
- discrete_diversity_matrix , index , valid_loader . dataset
321
+ discrete_diversity_matrix_valid , i
321
322
)
322
323
if positive_index is None or negative_index is None :
323
324
continue
@@ -477,28 +478,8 @@ def train_model_accuracy(
477
478
return train_losses , valid_losses
478
479
479
480
480
- def get_positive_and_negative (diversity_matrix , index , dataset = None ):
481
- positive = np .where (
482
- (diversity_matrix [index , :] == 1 ) & (np .arange (len (diversity_matrix )) != index )
483
- )[0 ].tolist ()
484
- negative = np .where (diversity_matrix [index , :] == - 1 )[0 ].tolist ()
485
-
486
- if dataset is not None :
487
- appropriate_indexes = [dataset [i ][2 ] for i in range (len (dataset ))]
488
-
489
- positive = [
490
- appropriate_indexes .index (idx )
491
- for idx in positive
492
- if idx in appropriate_indexes
493
- ]
494
- negative = [
495
- appropriate_indexes .index (idx )
496
- for idx in negative
497
- if idx in appropriate_indexes
498
- ]
499
-
500
- if not positive or not negative :
501
- print ("Both positive and negative samples are empty!" )
502
- return None , None
481
+ def get_positive_and_negative (diversity_matrix , index ):
482
+ positive = np .where (diversity_matrix [index ] == 1 )[0 ]
483
+ negative = np .where (diversity_matrix [index ] == 0 )[0 ]
503
484
504
485
return np .random .choice (positive ), np .random .choice (negative )
0 commit comments