Skip to content

Commit 287d9b7

Browse files
committed
now for training diversity model we use train and valid diversity matrices
1 parent ab08b84 commit 287d9b7

File tree

2 files changed

+456
-442
lines changed

2 files changed

+456
-442
lines changed

code/GCN.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,8 @@ def train_model_diversity(
241241
model,
242242
train_loader,
243243
valid_loader,
244-
discrete_diversity_matrix,
244+
discrete_diversity_matrix_train,
245+
discrete_diversity_matrix_valid,
245246
optimizer,
246247
criterion,
247248
num_epochs,
@@ -270,7 +271,7 @@ def train_model_diversity(
270271
anchor = model(features, edge_index)
271272

272273
positive_index, negative_index = get_positive_and_negative(
273-
discrete_diversity_matrix, index, train_loader.dataset
274+
discrete_diversity_matrix_train, i
274275
)
275276
if positive_index is None or negative_index is None:
276277
continue
@@ -317,7 +318,7 @@ def train_model_diversity(
317318
anchor = model(features, edge_index)
318319

319320
positive_index, negative_index = get_positive_and_negative(
320-
discrete_diversity_matrix, index, valid_loader.dataset
321+
discrete_diversity_matrix_valid, i
321322
)
322323
if positive_index is None or negative_index is None:
323324
continue
@@ -477,28 +478,8 @@ def train_model_accuracy(
477478
return train_losses, valid_losses
478479

479480

480-
def get_positive_and_negative(diversity_matrix, index, dataset=None):
481-
positive = np.where(
482-
(diversity_matrix[index, :] == 1) & (np.arange(len(diversity_matrix)) != index)
483-
)[0].tolist()
484-
negative = np.where(diversity_matrix[index, :] == -1)[0].tolist()
485-
486-
if dataset is not None:
487-
appropriate_indexes = [dataset[i][2] for i in range(len(dataset))]
488-
489-
positive = [
490-
appropriate_indexes.index(idx)
491-
for idx in positive
492-
if idx in appropriate_indexes
493-
]
494-
negative = [
495-
appropriate_indexes.index(idx)
496-
for idx in negative
497-
if idx in appropriate_indexes
498-
]
499-
500-
if not positive or not negative:
501-
print("Both positive and negative samples are empty!")
502-
return None, None
481+
def get_positive_and_negative(diversity_matrix, index):
482+
positive = np.where(diversity_matrix[index] == 1)[0]
483+
negative = np.where(diversity_matrix[index] == 0)[0]
503484

504485
return np.random.choice(positive), np.random.choice(negative)

0 commit comments

Comments
 (0)