Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/main/lib/elastic_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_presto_request_response(modality, callback_url, task):

def requires_encoding(obj):
for model_key in obj.get("models", []):
if not obj.get('model_'+model_key):
if model_key != "elasticsearch" and not obj.get('model_'+model_key):
return True
return False

Expand Down
31 changes: 17 additions & 14 deletions app/main/lib/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,17 @@ def generate_matches(context):
matches = []
clause_count = 0
for key in context:
if isinstance(context[key], list):
clause_count += len(context[key])
matches.append({
'query_string': { 'query': str.join(" OR ", [f"context.{key}: {v}" for v in context[key]])}
})
else:
clause_count += 1
matches.append({
'match': { 'context.' + key: context[key] }
})
if key not in ["project_media_id", "has_custom_id", "field"]:
if isinstance(context[key], list):
clause_count += len(context[key])
matches.append({
'query_string': { 'query': str.join(" OR ", [f"context.{key}: {v}" for v in context[key]])}
})
else:
clause_count += 1
matches.append({
'match': { 'context.' + key: context[key] }
})
return matches, clause_count

def truncate_query(query, clause_count):
Expand Down Expand Up @@ -112,12 +113,14 @@ def get_by_doc_id(doc_id):
return response['_source']

def store_document(body, doc_id, language=None):
for field in ["per_model_threshold", "threshold", "model", "confirmed", "limit", "requires_callback"]:
body.pop(field, None)
storable_doc = {}
for k,v in body.items():
if k not in ["per_model_threshold", "threshold", "model", "confirmed", "limit", "requires_callback"]:
storable_doc[k] = v
indices = [app.config['ELASTICSEARCH_SIMILARITY']]
# 'auto' indicates we should try to guess the appropriate language
if language == 'auto':
text = body['content']
text = storable_doc['content']
language = LangidProvider.langid(text)['result']['language']
if language not in SUPPORTED_LANGUAGES:
app.logger.warning('Detected language {} is not supported'.format(language))
Expand All @@ -129,7 +132,7 @@ def store_document(body, doc_id, language=None):

results = []
for index in indices:
index_result = update_or_create_document(body, doc_id, index)
index_result = update_or_create_document(storable_doc, doc_id, index)
results.append(index_result)
if index_result['result'] not in ['created', 'updated', 'noop']:
app.logger.warning('Problem adding document to ES index for language {0}: {1}'.format(language, index_result))
Expand Down
19 changes: 12 additions & 7 deletions app/main/lib/text_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def async_search_text(task, modality):
return elastic_crud.get_async_presto_response(task, "text", modality)

def fill_in_openai_embeddings(document):
for model_key in document.pop("models", []):
for model_key in document.get("models", []):
if model_key != "elasticsearch" and model_key[:len(PREFIX_OPENAI)] == PREFIX_OPENAI:
document['vector_'+model_key] = retrieve_openai_embeddings(document['content'], model_key)
document['model_'+model_key] = 1
Expand Down Expand Up @@ -76,7 +76,7 @@ def search_text(search_params, use_document_vectors=False):
if model_key != "elasticsearch":
search_params.pop("model", None)
if use_document_vectors:
vector_for_search = search_params[model_key+"-tokens"]
vector_for_search = search_params["vector_"+model_key]
else:
vector_for_search = None
result = search_text_by_model(dict(**search_params, **{'model': model_key}), vector_for_search)
Expand Down Expand Up @@ -175,6 +175,9 @@ def insert_model_into_response(hits, model_key):
hit["_source"]["model"] = model_key
return hits

def return_sources(results):
return [dict(**r["_source"], **{"index": r["_index"], "score": r["_score"]}) for r in results]

def strip_vectors(results):
for result in results:
vector_keys = [key for key in result["_source"].keys() if key[:7] == "vector_"]
Expand Down Expand Up @@ -260,11 +263,13 @@ def search_text_by_model(search_params, vector_for_search):
body=body,
index=search_indices
)
response = strip_vectors(
restrict_results(
insert_model_into_response(result['hits']['hits'], model_key),
search_params,
model_key
response = return_sources(
strip_vectors(
restrict_results(
insert_model_into_response(result['hits']['hits'], model_key),
search_params,
model_key
)
)
)
return {
Expand Down
12 changes: 6 additions & 6 deletions app/test/test_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,10 @@ def test_elasticsearch_performs_correct_fuzzy_search(self):
post_response = self.client.post('/text/similarity/search/', data=json.dumps(lookup), content_type='application/json')
lookup["fuzzy"] = True
post_response_fuzzy = self.client.post('/text/similarity/search/', data=json.dumps(lookup), content_type='application/json')
self.assertGreater(json.loads(post_response_fuzzy.data.decode())["result"][0]["_score"], json.loads(post_response.data.decode())["result"][0]["_score"])
self.assertGreater(json.loads(post_response_fuzzy.data.decode())["result"][0]["score"], json.loads(post_response.data.decode())["result"][0]["score"])
lookup["fuzzy"] = False
post_response_fuzzy = self.client.post('/text/similarity/search/', data=json.dumps(lookup), content_type='application/json')
self.assertEqual(json.loads(post_response_fuzzy.data.decode())["result"][0]["_score"], json.loads(post_response.data.decode())["result"][0]["_score"])
self.assertEqual(json.loads(post_response_fuzzy.data.decode())["result"][0]["score"], json.loads(post_response.data.decode())["result"][0]["score"])

def test_elasticsearch_update_text(self):
with self.client:
Expand Down Expand Up @@ -455,7 +455,7 @@ def test_model_similarity(self):
)
result = json.loads(response.data.decode())
self.assertEqual(1, len(result['result']))
similarity = result['result'][0]['_score']
similarity = result['result'][0]['score']
self.assertGreater(similarity, 0.7)

response = self.client.post(
Expand Down Expand Up @@ -487,7 +487,7 @@ def test_model_similarity(self):
)
result = json.loads(response.data.decode())
self.assertEqual(1, len(result['result']))
similarity = result['result'][0]['_score']
similarity = result['result'][0]['score']
self.assertGreater(similarity, 0.7)

response = self.client.post(
Expand All @@ -501,7 +501,7 @@ def test_model_similarity(self):
)
result = json.loads(response.data.decode())
self.assertEqual(1, len(result['result']))
similarity = result['result'][0]['_score']
similarity = result['result'][0]['score']
self.assertGreater(similarity, 0.7)

def test_wrong_model_key(self):
Expand Down Expand Up @@ -599,7 +599,7 @@ def test_min_es_search(self):
result = json.loads(response.data.decode())

self.assertEqual(1, len(result['result']))
data['min_es_score']=10+result['result'][0]['_score']
data['min_es_score']=10+result['result'][0]['score']

response = self.client.post(
'/text/similarity/search/',
Expand Down
6 changes: 3 additions & 3 deletions app/test/test_similarity_lang_analyzers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_all_analyzers(self):
content_type='application/json'
)
result = json.loads(response.data.decode())
self.assertTrue(app.config['ELASTICSEARCH_SIMILARITY']+"_"+example['language'] in [e['_index'] for e in result['result']])
self.assertTrue(app.config['ELASTICSEARCH_SIMILARITY']+"_"+example['language'] in [e['index'] for e in result['result']])

def test_auto_language_id(self):
# language examples as input to language classifier
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_auto_language_id(self):
index_alias = app.config['ELASTICSEARCH_SIMILARITY']
if expected_lang is not None:
index_alias = app.config['ELASTICSEARCH_SIMILARITY']+"_"+expected_lang
self.assertTrue(index_alias in [e['_index'] for e in result['result']])
self.assertTrue(index_alias in [e['index'] for e in result['result']])

def test_auto_language_query(self):
# language examples as input to language classifier
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_auto_language_query(self):
index_alias = app.config['ELASTICSEARCH_SIMILARITY']
if expected_lang is not None:
index_alias = app.config['ELASTICSEARCH_SIMILARITY']+"_"+expected_lang
self.assertTrue(index_alias in [e['_index'] for e in result['result']])
self.assertTrue(index_alias in [e['index'] for e in result['result']])


if __name__ == '__main__':
Expand Down
Loading