Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions app/main/lib/elastic_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from flask import current_app as app
from app.main.lib.presto import Presto, PRESTO_MODEL_MAP
from app.main.lib.elasticsearch import store_document, get_by_doc_id

from app.main.lib.openai import PREFIX_OPENAI
def _after_log(retry_state):
app.logger.debug("Retrying image similarity...")

Expand Down Expand Up @@ -40,9 +40,12 @@ def get_presto_request_response(modality, callback_url, task):
assert isinstance(response["body"], dict), f"Bad body for {modality}, {callback_url}, {task} - response was {response}"
return response

def encodable_model(model_key, obj):
return model_key != "elasticsearch" and not obj.get('model_'+model_key) and model_key[:len(PREFIX_OPENAI)] != PREFIX_OPENAI

def requires_encoding(obj):
for model_key in obj.get("models", []):
if model_key != "elasticsearch" and not obj.get('model_'+model_key):
if encodable_model(model_key, obj):
return True
return False

Expand All @@ -55,7 +58,7 @@ def get_blocked_presto_response(task, model, modality):
if requires_encoding(obj):
blocked_results = []
for model_key in obj.get("models", []):
if model_key != "elasticsearch" and not obj.get('model_'+model_key):
if encodable_model(model_key, obj):
response = get_presto_request_response(model_key, callback_url, obj)
blocked_results.append({"model": model_key, "response": Presto.blocked_response(response, modality)})
# Warning: this is a blocking hold to wait until we get a response in
Expand All @@ -73,7 +76,7 @@ def get_async_presto_response(task, model, modality):
if requires_encoding(obj):
responses = []
for model_key in obj.get("models", []):
if model_key != "elasticsearch" and not obj.get('model_'+model_key):
if encodable_model(model_key, obj):
task["model"] = model_key
responses.append(get_presto_request_response(model_key, callback_url, task))
return responses, True
Expand Down
4 changes: 4 additions & 0 deletions app/test/test_elastic_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def test_requires_encoding(self):
obj = {'models': ['model1'], 'model_model1': 'encoded_data'}
self.assertFalse(requires_encoding(obj))

obj = {'models': ['openai-text-embedding-ada-002']}
self.assertFalse(requires_encoding(obj))


@patch('app.main.lib.elastic_crud.Presto.blocked_response')
@patch('app.main.lib.elastic_crud.Presto.send_request')
@patch('app.main.lib.elastic_crud.store_document')
Expand Down
Loading