diff --git a/app/main/lib/text_similarity.py b/app/main/lib/text_similarity.py index ff37f6ce..1a62afa8 100644 --- a/app/main/lib/text_similarity.py +++ b/app/main/lib/text_similarity.py @@ -201,13 +201,24 @@ def strip_vectors(results): return results def restrict_results(results, search_params, model_key): + """ + min_es_score is the minimum elasticsearch score needed to include a result. + This is applied after the results are retrieved from Elasticsearch. + The threshold parameter is in the range [0,1] and used when making the query. + The min_es_score is used after results are retrieved and applied to the + Elasticsearch scores, which are in the range [0, +Inf]. + """ out_results = [] - try: - min_es_score = float(search_params.get("min_es_score")) - except (ValueError, TypeError) as e: - app.logger.info(f"search_params failed on min_es_score for {search_params}, raised error as {e}") - min_es_score = None + min_es_score = search_params.get("min_es_score") + if min_es_score is None: + min_es_score = 0.0 + app.logger.warning(f"min_es_score is missing or None, defaulting to {min_es_score}") if min_es_score is not None and model_key == "elasticsearch": + try: + min_es_score = float(min_es_score) + except ValueError as e: + app.logger.error(f"Invalid min_es_score '{min_es_score}': {e}") + raise(e) for result in results: if "_score" in result and min_es_score < result["_score"]: out_results.append(result) diff --git a/app/test/test_similarity.py b/app/test/test_similarity.py index c149af19..cec2787d 100644 --- a/app/test/test_similarity.py +++ b/app/test/test_similarity.py @@ -589,8 +589,9 @@ def test_model_similarity_with_vector(self): def test_min_es_search(self): + # confirm that min es filtering works with self.client: - data={ + data = { 'text':'min_es_score', 'models':['elasticsearch'], } @@ -609,7 +610,7 @@ def test_min_es_search(self): result = json.loads(response.data.decode()) self.assertEqual(1, len(result['result'])) - 
data['min_es_score']=10+result['result'][0]['score'] + data['min_es_score'] = 10+result['result'][0]['score'] response = self.client.post( '/text/similarity/search/', @@ -619,5 +620,29 @@ def test_min_es_search(self): result = json.loads(response.data.decode()) self.assertEqual(0, len(result['result'])) + # confirm that min_es_score missing or None: set to zero with warning + data['min_es_score'] = None + response2 = self.client.post( + '/text/similarity/search/', + data=json.dumps(data), + content_type='application/json' + ) + result2 = json.loads(response2.data.decode()) + self.assertEqual(1, len(result2['result'])) + + # confirm that min_es_score cannot parse as float: log error and raise exception? + # we won't see exception here, but result should not be success + data['min_es_score'] = 'fifty' + response = self.client.post( + '/text/similarity/search/', + data=json.dumps(data), + content_type='application/json' + ) + assert response.status_code == 500, f"status code was{response.status_code}" + result = json.loads(response.data.decode()) + self.assertIsNone(result.get('success')), f"result was {result}" + # TODO: is exception being swallowed? Need to confirm logging + + if __name__ == '__main__': unittest.main() \ No newline at end of file