Description
In the code that gets entity neighbors, `as_head_neighbors` and `as_tail_neighbors` are built in the same way. However, when the target entity is the tail, its one-step neighbors are head entities, so `as_tail_neighbors` comes out as `[source entity, relation, head entity]`. These triplets are reversed and therefore wrong. I think the code below is correct:
```python
def get_entity_neighbors(target_entity, max_triplet):
    # `graph` maps head -> {tail: relation}; `reverse_graph` maps tail -> {head: relation}.
    # Outgoing edges: target_entity is the head of each triplet.
    as_head_neighbors = get_onestep_neighbors(graph, target_entity, True, max_triplet // 2)
    # Incoming edges: target_entity is the tail, so look it up in the reverse graph.
    as_tail_neighbors = get_onestep_neighbors(reverse_graph, target_entity, False, max_triplet // 2)
    all_triplet = as_head_neighbors + as_tail_neighbors
    return all_triplet


def get_onestep_neighbors(graph, source, forward, sample_num):
    triplet = []
    try:
        nei = list(graph[source].keys())
        # nei = random.sample(list(graph[source].keys()), sample_num)
    except KeyError:
        # `source` has no edges in this direction.
        return triplet
    except ValueError:
        # random.sample raises ValueError when sample_num > number of neighbors;
        # fall back to all neighbors (direction is still handled below).
        nei = list(graph[source].keys())
    if forward:
        # (source, relation, tail)
        triplet = [(source, graph[source][n], n) for n in nei]
    else:
        # (head, relation, source): keep the original head -> tail direction
        triplet = [(n, graph[source][n], source) for n in nei]
    return triplet
```
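To make the difference concrete, here is a small self-contained sketch on a toy graph. It assumes the layout implied by the code above (`graph[head][tail] = relation`, `reverse_graph[tail][head] = relation`); the triples and the helper names `onestep` and `buggy_tail_neighbors` are hypothetical and only illustrate the issue, they are not the repository's code.

```python
# Hypothetical toy knowledge graph, assuming graph[head][tail] = relation
# and reverse_graph[tail][head] = relation.
triples = [
    ("Paris", "capital_of", "France"),
    ("Louvre", "located_in", "Paris"),
]

graph = {}
reverse_graph = {}
for head, rel, tail in triples:
    graph.setdefault(head, {})[tail] = rel
    reverse_graph.setdefault(tail, {})[head] = rel


def onestep(g, source, forward):
    """Hypothetical helper: one-step triplets around `source`, always (head, relation, tail)."""
    if source not in g:
        return []
    if forward:
        return [(source, rel, n) for n, rel in g[source].items()]
    return [(n, rel, source) for n, rel in g[source].items()]


def buggy_tail_neighbors(source):
    # Treats the reverse graph like the forward one, producing (tail, relation, head).
    return [(source, rel, n) for n, rel in reverse_graph.get(source, {}).items()]


entity = "Paris"
as_head = onestep(graph, entity, True)
as_tail = onestep(reverse_graph, entity, False)
print(as_head)                        # [('Paris', 'capital_of', 'France')]
print(as_tail)                        # [('Louvre', 'located_in', 'Paris')]
print(buggy_tail_neighbors(entity))   # [('Paris', 'located_in', 'Louvre')]
```

Every triplet from the fixed lookup appears in the original triple list, while the buggy version yields `('Paris', 'located_in', 'Louvre')`, which is not a fact in the graph.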