@@ -82,11 +82,23 @@ def val(eval_set, model, encoder_dim, device, opt, config, writer, epoch_num=0,
82
82
tqdm .write ('====> Calculating recall @ N' )
83
83
n_values = [1 , 5 , 10 , 20 , 50 , 100 ]
84
84
85
- _ , predictions = faiss_index .search (qFeat , max (n_values ))
86
-
87
85
# for each query get those within threshold distance
88
86
gt = eval_set .all_pos_indices
89
87
88
+ # any combination of mapillary cities will work as a val set
89
+ qEndPosTot = 0
90
+ dbEndPosTot = 0
91
+ for cityNum , (qEndPos , dbEndPos ) in enumerate (zip (eval_set .qEndPosList , eval_set .dbEndPosList )):
92
+ faiss_index = faiss .IndexFlatL2 (pool_size )
93
+ faiss_index .add (dbFeat [dbEndPosTot :dbEndPosTot + dbEndPos , :])
94
+ _ , preds = faiss_index .search (qFeat [qEndPosTot :qEndPosTot + qEndPos , :], max (n_values ))
95
+ if cityNum == 0 :
96
+ predictions = preds
97
+ else :
98
+ predictions = np .vstack ((predictions , preds ))
99
+ qEndPosTot += qEndPos
100
+ dbEndPosTot += dbEndPos
101
+
90
102
correct_at_n = np .zeros (len (n_values ))
91
103
# TODO can we do this on the matrix in one go?
92
104
for qIx , pred in enumerate (predictions ):
@@ -104,4 +116,4 @@ def val(eval_set, model, encoder_dim, device, opt, config, writer, epoch_num=0,
104
116
if write_tboard :
105
117
writer .add_scalar ('Val/Recall@' + str (n ), recall_at_n [i ], epoch_num )
106
118
107
- return all_recalls
119
+ return all_recalls
0 commit comments