@@ -241,8 +241,7 @@ def train_model_diversity(
241
241
model ,
242
242
train_loader ,
243
243
valid_loader ,
244
- discrete_diversity_matrix_train ,
245
- discrete_diversity_matrix_valid ,
244
+ discrete_diversity_matrix ,
246
245
optimizer ,
247
246
criterion ,
248
247
num_epochs ,
@@ -271,7 +270,7 @@ def train_model_diversity(
271
270
anchor = model (features , edge_index )
272
271
273
272
positive_index , negative_index = get_positive_and_negative (
274
- discrete_diversity_matrix_train , i
273
+ discrete_diversity_matrix , index , train_loader . dataset
275
274
)
276
275
if positive_index is None or negative_index is None :
277
276
continue
@@ -318,7 +317,7 @@ def train_model_diversity(
318
317
anchor = model (features , edge_index )
319
318
320
319
positive_index , negative_index = get_positive_and_negative (
321
- discrete_diversity_matrix_valid , i
320
+ discrete_diversity_matrix , index , valid_loader . dataset
322
321
)
323
322
if positive_index is None or negative_index is None :
324
323
continue
@@ -478,8 +477,28 @@ def train_model_accuracy(
478
477
return train_losses , valid_losses
479
478
480
479
481
- def get_positive_and_negative (diversity_matrix , index ):
482
- positive = np .where (diversity_matrix [index ] == 1 )[0 ]
483
- negative = np .where (diversity_matrix [index ] == 0 )[0 ]
480
+ def get_positive_and_negative (diversity_matrix , index , dataset = None ):
481
+ positive = np .where (
482
+ (diversity_matrix [index , :] == 1 ) & (np .arange (len (diversity_matrix )) != index )
483
+ )[0 ].tolist ()
484
+ negative = np .where (diversity_matrix [index , :] == - 1 )[0 ].tolist ()
485
+
486
+ if dataset is not None :
487
+ appropriate_indexes = [dataset [i ][2 ] for i in range (len (dataset ))]
488
+
489
+ positive = [
490
+ appropriate_indexes .index (idx )
491
+ for idx in positive
492
+ if idx in appropriate_indexes
493
+ ]
494
+ negative = [
495
+ appropriate_indexes .index (idx )
496
+ for idx in negative
497
+ if idx in appropriate_indexes
498
+ ]
499
+
500
+ if not positive or not negative :
501
+ print ("Both positive and negative samples are empty!" )
502
+ return None , None
484
503
485
504
return np .random .choice (positive ), np .random .choice (negative )
0 commit comments