Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions app/main/lib/text_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,24 @@ def strip_vectors(results):
return results

def restrict_results(results, search_params, model_key):
"""
min_es_score is the minimum elasticsearch score needed to include a result.
This is applied after the results are retrieved from Elasticsearch.
The threshold parameter is in the range [0,1] and used when making the query.
The min_es_score is used after results are retrieved and applied to the
Elasticsearch scores, which are in the range [0, +Inf].
"""
out_results = []
try:
min_es_score = float(search_params.get("min_es_score"))
except (ValueError, TypeError) as e:
app.logger.info(f"search_params failed on min_es_score for {search_params}, raised error as {e}")
min_es_score = None
min_es_score = search_params.get("min_es_score")
if min_es_score is None:
min_es_score = 0.0
app.logger.warning(f"min_es_score is missing or None, defaulting to {min_es_score}")
if min_es_score is not None and model_key == "elasticsearch":
try:
min_es_score = float(min_es_score)
except ValueError as e:
app.logger.error(f"Invalid min_es_score '{min_es_score}': {e}")
raise(e)
for result in results:
if "_score" in result and min_es_score < result["_score"]:
out_results.append(result)
Expand Down
29 changes: 27 additions & 2 deletions app/test/test_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,8 +589,9 @@ def test_model_similarity_with_vector(self):


def test_min_es_search(self):
# confirm that min es filtering works
with self.client:
data={
data = {
'text':'min_es_score',
'models':['elasticsearch'],
}
Expand All @@ -609,7 +610,7 @@ def test_min_es_search(self):
result = json.loads(response.data.decode())

self.assertEqual(1, len(result['result']))
data['min_es_score']=10+result['result'][0]['score']
data['min_es_score'] = 10+result['result'][0]['score']

response = self.client.post(
'/text/similarity/search/',
Expand All @@ -619,5 +620,29 @@ def test_min_es_search(self):
result = json.loads(response.data.decode())
self.assertEqual(0, len(result['result']))

# confirm that min_es_score missing or None: set to zero with warning
data['min_es_score'] = None
response2 = self.client.post(
'/text/similarity/search/',
data=json.dumps(data),
content_type='application/json'
)
result2 = json.loads(response2.data.decode())
self.assertEqual(1, len(result2['result']))

# confirm that min_es_score cannot parse as float: log error and raise exception?
# we won't see exception here, but result should not be sucess
data['min_es_score'] = 'fifty'
response = self.client.post(
'/text/similarity/search/',
data=json.dumps(data),
content_type='application/json'
)
assert response.status_code == 500, f"status code was{response.status_code}"
result = json.loads(response.data.decode())
self.assertIsNone(result.get('success')), f"result was {result}"
# TODO: is excption being swollowed? Need to confirm logging


if __name__ == '__main__':
unittest.main()