Skip to content

Commit 9261ea2

Browse files
committed
turn back functions in GCN.py
1 parent c4db85d commit 9261ea2

File tree

3 files changed

+28
-11
lines changed

3 files changed

+28
-11
lines changed

code/GCN.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,7 @@ def train_model_diversity(
241241
model,
242242
train_loader,
243243
valid_loader,
244-
discrete_diversity_matrix_train,
245-
discrete_diversity_matrix_valid,
244+
discrete_diversity_matrix,
246245
optimizer,
247246
criterion,
248247
num_epochs,
@@ -271,7 +270,7 @@ def train_model_diversity(
271270
anchor = model(features, edge_index)
272271

273272
positive_index, negative_index = get_positive_and_negative(
274-
discrete_diversity_matrix_train, i
273+
discrete_diversity_matrix, index, train_loader.dataset
275274
)
276275
if positive_index is None or negative_index is None:
277276
continue
@@ -318,7 +317,7 @@ def train_model_diversity(
318317
anchor = model(features, edge_index)
319318

320319
positive_index, negative_index = get_positive_and_negative(
321-
discrete_diversity_matrix_valid, i
320+
discrete_diversity_matrix, index, valid_loader.dataset
322321
)
323322
if positive_index is None or negative_index is None:
324323
continue
@@ -478,8 +477,28 @@ def train_model_accuracy(
478477
return train_losses, valid_losses
479478

480479

481-
def get_positive_and_negative(diversity_matrix, index):
482-
positive = np.where(diversity_matrix[index] == 1)[0]
483-
negative = np.where(diversity_matrix[index] == 0)[0]
480+
def get_positive_and_negative(diversity_matrix, index, dataset=None):
481+
positive = np.where(
482+
(diversity_matrix[index, :] == 1) & (np.arange(len(diversity_matrix)) != index)
483+
)[0].tolist()
484+
negative = np.where(diversity_matrix[index, :] == -1)[0].tolist()
485+
486+
if dataset is not None:
487+
appropriate_indexes = [dataset[i][2] for i in range(len(dataset))]
488+
489+
positive = [
490+
appropriate_indexes.index(idx)
491+
for idx in positive
492+
if idx in appropriate_indexes
493+
]
494+
negative = [
495+
appropriate_indexes.index(idx)
496+
for idx in negative
497+
if idx in appropriate_indexes
498+
]
499+
500+
if not positive or not negative:
501+
print("Both positive and negative samples are empty!")
502+
return None, None
484503

485504
return np.random.choice(positive), np.random.choice(negative)

code/create_GCN_dataset.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,9 @@ def apply_argmax_to_predictions(data):
5353
save_dicts_as_json(first_arch_dicts, output_dir)
5454

5555
apply_argmax_to_predictions(first_arch_dicts)
56-
output_dir = "tmp_dataset"
57-
save_dicts_as_json(first_arch_dicts, output_dir)
5856

5957
second_arch_dicts = load_json_from_directory("second_dataset")
58+
second_arch_dicts.extend(first_arch_dicts)
6059

61-
first_arch_dicts.extend(second_arch_dicts)
6260
output_dir = "third_dataset"
63-
save_dicts_as_json(first_arch_dicts, output_dir)
61+
save_dicts_as_json(second_arch_dicts, output_dir)

code/dependecies.zip

5.6 KB
Binary file not shown.

0 commit comments

Comments
 (0)