Skip to content

Commit b9224b3

Browse files
Merge pull request #44 from QVPR/issue_43_fix
Fixing val.py incorrect recalls issue #43
2 parents 6005b55 + f06d664 commit b9224b3

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

patchnetvlad/training_tools/msls.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def __init__(self, root_dir, cities='', nNeg=5, transform=None, mode='train', ta
9999
self.dbImages = []
100100
self.sideways = []
101101
self.night = []
102+
self.qEndPosList = []
103+
self.dbEndPosList = []
102104

103105
self.all_pos_indices = []
104106

@@ -186,6 +188,9 @@ def __init__(self, root_dir, cities='', nNeg=5, transform=None, mode='train', ta
186188
self.qImages.extend(qSeqKeys)
187189
self.dbImages.extend(dbSeqKeys)
188190

191+
self.qEndPosList.append(len(qSeqKeys))
192+
self.dbEndPosList.append(len(dbSeqKeys))
193+
189194
qData = qData.loc[unique_qSeqIdx]
190195
dbData = dbData.loc[unique_dbSeqIdx]
191196

patchnetvlad/training_tools/val.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,23 @@ def val(eval_set, model, encoder_dim, device, opt, config, writer, epoch_num=0,
8282
tqdm.write('====> Calculating recall @ N')
8383
n_values = [1, 5, 10, 20, 50, 100]
8484

85-
_, predictions = faiss_index.search(qFeat, max(n_values))
86-
8785
# for each query get those within threshold distance
8886
gt = eval_set.all_pos_indices
8987

88+
# any combination of mapillary cities will work as a val set
89+
qEndPosTot = 0
90+
dbEndPosTot = 0
91+
for cityNum, (qEndPos, dbEndPos) in enumerate(zip(eval_set.qEndPosList, eval_set.dbEndPosList)):
92+
faiss_index = faiss.IndexFlatL2(pool_size)
93+
faiss_index.add(dbFeat[dbEndPosTot:dbEndPosTot+dbEndPos, :])
94+
_, preds = faiss_index.search(qFeat[qEndPosTot:qEndPosTot+qEndPos, :], max(n_values))
95+
if cityNum == 0:
96+
predictions = preds
97+
else:
98+
predictions = np.vstack((predictions, preds))
99+
qEndPosTot += qEndPos
100+
dbEndPosTot += dbEndPos
101+
90102
correct_at_n = np.zeros(len(n_values))
91103
# TODO can we do this on the matrix in one go?
92104
for qIx, pred in enumerate(predictions):
@@ -104,4 +116,4 @@ def val(eval_set, model, encoder_dim, device, opt, config, writer, epoch_num=0,
104116
if write_tboard:
105117
writer.add_scalar('Val/Recall@' + str(n), recall_at_n[i], epoch_num)
106118

107-
return all_recalls
119+
return all_recalls

0 commit comments

Comments
 (0)