
fix: Use singleton PostgresDBClient (Sqlalchemy engine) #321


Merged: 21 commits, Jun 23, 2025
Changes from 19 commits
4 changes: 2 additions & 2 deletions .github/workflows/promptfoo-googlesheet-evaluation.yml
@@ -118,9 +118,9 @@ jobs:
EVAL_OUTPUT_FILE="/tmp/promptfoo-output.txt"

if [ -n "$PROMPTFOO_API_KEY" ]; then
-promptfoo eval --max-concurrency 1 --config "/tmp/promptfooconfig.processed.yaml" --share --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}"
+promptfoo eval --config "/tmp/promptfooconfig.processed.yaml" --share --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}"
else
-promptfoo eval --max-concurrency 1 --config "/tmp/promptfooconfig.processed.yaml" --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}"
+promptfoo eval --config "/tmp/promptfooconfig.processed.yaml" --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}"
Comment on lines -121 to +123
Contributor Author:

Reverts #317, so it should default back to 4 threads

fi

if [ -f "${OUTPUT_JSON_FILE}" ]; then
11 changes: 10 additions & 1 deletion app/src/app_config.py
@@ -42,8 +42,17 @@ class AppConfig(PydanticBaseEnvConfig):
# If set, used instead of LITERAL_API_KEY for API
literal_api_key_for_api: str | None = None

+@cached_property
+def db_client(self) -> db.PostgresDBClient:
+return db.PostgresDBClient()
Comment on lines +45 to +47
Contributor Author:

This effectively creates a singleton PostgresDBClient instance, which holds the SQLAlchemy engine.

“the engine is thread safe yes. individual Connection objects are not. we try to describe this at Working with Engines and Connections — SQLAlchemy 2.0 Documentation”
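
Since @cached_property computes the value once per AppConfig instance and then caches it, every access through the same instance reuses one PostgresDBClient, and therefore one SQLAlchemy engine and connection pool. A minimal sketch of that behavior, using a stand-in class rather than the real db.PostgresDBClient:

```python
from functools import cached_property


class FakeDBClient:
    """Stand-in for db.PostgresDBClient, which builds the SQLAlchemy engine."""


class AppConfig:
    @cached_property
    def db_client(self) -> FakeDBClient:
        # Runs once; later attribute accesses return the cached instance.
        return FakeDBClient()


config = AppConfig()
assert config.db_client is config.db_client  # one client (and one engine/pool) per config instance
```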


+# def db_session(self) -> db.Session:
+# import pdb
+# pdb.set_trace()
+# return db.PostgresDBClient().get_session()

def db_session(self) -> db.Session:
-return db.PostgresDBClient().get_session()
+return self.db_client.get_session()
Comment on lines -46 to +50
Contributor Author:

A new session uses an available connection (from the connection pool), so threads (e.g., those created from API calls) will not share a connection.

Our pool size is 20 – what happens when the max SQLAlchemy pool size is reached?

Google’s AI states:

Requests are queued: Any new requests for a database connection are placed in a queue, waiting for a connection to become available.

Timeout begins: SQLAlchemy starts a timeout period (defaulting to 30 seconds, but configurable) to see if a connection is released back into the pool.

Connection timeout error: If a connection doesn't become available within the timeout period, an exception (e.g., TimeoutError) is thrown, indicating a connection timeout.

From “SQLAlchemy connection pooling, what are checked out connections?”: “If all the connections are simultaneously checked out then you can expect an error (there will be a timeout period during which SQLAlchemy waits to see if a connection gets freed up; this is also configurable).”
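
A minimal sketch of how those pool limits map onto engine configuration; the URL and values below are illustrative assumptions, not necessarily what PostgresDBClient actually passes:

```python
from sqlalchemy import create_engine, text

engine = create_engine(
    "postgresql+psycopg2://user:pass@localhost/app",  # placeholder URL
    pool_size=20,     # connections kept in the pool
    max_overflow=0,   # extra connections allowed beyond pool_size
    pool_timeout=30,  # seconds to wait for a free connection before giving up
)

# Each connection (or Session) checks a connection out of the pool and returns it on close.
# If all pool_size + max_overflow connections are checked out and none is freed within
# pool_timeout seconds, SQLAlchemy raises sqlalchemy.exc.TimeoutError.
with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
```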


@cached_property
def embedding_model(self) -> EmbeddingModel:
5 changes: 5 additions & 0 deletions app/src/generate.py
@@ -25,6 +25,11 @@ def get_models() -> dict[str, str]:
models |= {"OpenAI GPT-4o": "gpt-4o"}
if "ANTHROPIC_API_KEY" in os.environ:
models |= {"Anthropic Claude 3.5 Sonnet": "claude-3-5-sonnet-20240620"}
if "GEMINI_API_KEY" in os.environ:
models |= {
"Google Gemini 1.5 Pro": "gemini/gemini-1.5-pro",
"Google Gemini 2.5 Pro": "gemini/gemini-2.5-pro-preview-06-05",
}
if _has_aws_access():
# If you get "You don't have access to the model with the specified model ID." error,
# remember to request access to Bedrock models ...aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess
4 changes: 2 additions & 2 deletions app/tests/conftest.py
@@ -108,8 +108,8 @@ def enable_factory_create(monkeypatch, db_session) -> db.Session:


@pytest.fixture
-def app_config(monkeypatch, db_session):
-monkeypatch.setattr(AppConfig, "db_session", lambda _self: db_session)
+def app_config(monkeypatch, db_client: db.DBClient):
+monkeypatch.setattr(AppConfig, "db_session", lambda _self: db_client.get_session())
monkeypatch.setattr(AppConfig, "embedding_model", MockEmbeddingModel())


16 changes: 9 additions & 7 deletions app/tests/src/db/test_pg_dump_util.py
@@ -25,7 +25,7 @@ def populated_db(db_session):
ChatMessageFactory.create_batch(4, session=user_session)


-def test_backup_and_restore_db(enable_factory_create, populated_db, caplog):
+def test_backup_and_restore_db(enable_factory_create, populated_db, app_config, caplog):
with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
dumpfile = f"{tmpdirname}/db.dump"
pg_dump_util.backup_db(dumpfile, "local")
@@ -43,7 +43,7 @@ def test_backup_and_restore_db(enable_factory_create, populated_db, caplog):
assert f"DB data restored from {dumpfile!r}" in caplog.messages


-def test_restore_db_without_truncating(caplog):
+def test_restore_db_without_truncating(app_config, caplog):
with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
dumpfile = f"{tmpdirname}/db.dump"
with open(dumpfile, "wb"): # Create an empty file
@@ -54,7 +54,7 @@ def test_restore_db_without_truncating(caplog):
)


-def test_backup_db__file_exists(caplog):
+def test_backup_db__file_exists(app_config, caplog):
with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
dumpfile = f"{tmpdirname}/db.dump"
with open(dumpfile, "wb"): # Create an empty file
@@ -68,7 +68,7 @@ def test_backup_db__file_exists(caplog):
assert f"DB data dumped to '{tmpdirname}/db.dump'" not in caplog.messages


-def test_backup_db__dump_failure(caplog, monkeypatch):
+def test_backup_db__dump_failure(app_config, caplog, monkeypatch):
monkeypatch.setattr(pg_dump_util, "_pg_dump", lambda *args: False)
with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
dumpfile = f"{tmpdirname}/db.dump"
@@ -77,7 +77,7 @@ def test_backup_db__dump_failure(caplog, monkeypatch):
assert f"Failed to dump DB data to '{tmpdirname}/db.dump'" in caplog.messages


-def test_backup_db__truncate_failure(caplog, monkeypatch):
+def test_backup_db__truncate_failure(app_config, caplog, monkeypatch):
monkeypatch.setattr(pg_dump_util, "_truncate_db_tables", lambda *args: False)
with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
dumpfile = f"{tmpdirname}/db.dump"
@@ -95,7 +95,9 @@ def mock_s3_dev_bucket(mock_s3):
yield bucket


-def test_backup_db_for_dev(enable_factory_create, populated_db, caplog, mock_s3_dev_bucket):
+def test_backup_db_for_dev(
+enable_factory_create, populated_db, app_config, caplog, mock_s3_dev_bucket
+):
with caplog.at_level(logging.INFO):
dumpfile = "dev_db.dump"
pg_dump_util.backup_db(dumpfile, "dev")
@@ -109,7 +111,7 @@ def test_backup_db_for_dev(enable_factory_create, populated_db, caplog, mock_s3_
)


-def test_restore_db_failure(caplog):
+def test_restore_db_failure(app_config, caplog):
with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
dumpfile = f"{tmpdirname}/db.dump"

4 changes: 2 additions & 2 deletions app/tests/src/evaluation/cli/test_cli_generate.py
@@ -91,7 +91,7 @@ def test_invalid_arguments():


@pytest.mark.integration
-def test_main_integration(temp_output_dir):
+def test_main_integration(temp_output_dir, app_config):
"""Integration test with minimal test data."""
with mock.patch(
"sys.argv",
@@ -122,7 +122,7 @@ def test_main_integration(temp_output_dir):
assert mock_qa_pairs_path.exists()


-def test_error_handling_no_documents():
+def test_error_handling_no_documents(app_config):
"""Test handling of 'No documents found' error."""
with mock.patch("sys.argv", ["generate.py"]):
with mock.patch("src.evaluation.qa_generation.runner.run_generation") as mock_run:
@@ -110,7 +110,7 @@ def test_qa_storage(qa_pairs, tmp_path):


def test_run_generation_basic(
-mock_completion_response, tmp_path, enable_factory_create, db_session
+mock_completion_response, tmp_path, enable_factory_create, db_session, app_config
):
"""Test basic run_generation functionality."""
output_dir = tmp_path / "qa_output"
@@ -147,7 +147,7 @@


def test_run_generation_with_dataset_filter(
-mock_completion_response, tmp_path, enable_factory_create, db_session
+mock_completion_response, tmp_path, enable_factory_create, db_session, app_config
):
"""Test run_generation with dataset filter."""
output_dir = tmp_path / "qa_output"
@@ -261,7 +261,7 @@ def test_save_qa_pairs(qa_pairs, tmp_path):


def test_run_generation_with_version(
-mock_completion_response, tmp_path, enable_factory_create, db_session
+mock_completion_response, tmp_path, enable_factory_create, db_session, app_config
):
"""Test run_generation with version parameter."""
output_dir = tmp_path / "qa_output"
@@ -301,7 +301,9 @@ def test_run_generation_with_version(
assert pair.version.llm_model == version.llm_model


-def test_run_generation_invalid_sample_fraction(tmp_path, enable_factory_create, db_session):
+def test_run_generation_invalid_sample_fraction(
+tmp_path, enable_factory_create, db_session, app_config
+):
"""Test run_generation with invalid sample fraction values."""
output_dir = tmp_path / "qa_output"
llm_model = "gpt-4o-mini"
4 changes: 2 additions & 2 deletions app/tests/src/test_chat_api.py
@@ -38,7 +38,7 @@ def no_literalai_data_layer(monkeypatch):


@pytest.fixture
-def async_client(no_literalai_data_layer, db_session):
+def async_client(no_literalai_data_layer, db_session, app_config):
"""
The typical FastAPI TestClient creates its own event loop to handle requests,
which led to issues when testing code that relies on asynchronous operations
@@ -108,7 +108,7 @@ async def test_api_engines(async_client, db_session):


@pytest.mark.asyncio
-async def test_api_engines__dbsession_contextvar(async_client, monkeypatch, db_session):
+async def test_api_engines__dbsession_contextvar(async_client, monkeypatch, db_session, app_config):
event = asyncio.Event()
db_sessions = []
orig_init_chat_session = chat_api._init_chat_session
2 changes: 2 additions & 0 deletions app/tests/src/test_generate.py
@@ -81,6 +81,8 @@ def test_get_models(monkeypatch):
monkeypatch.delenv("AWS_ACCESS_KEY_ID")
if "OPENAI_API_KEY" in os.environ:
monkeypatch.delenv("OPENAI_API_KEY")
if "GEMINI_API_KEY" in os.environ:
monkeypatch.delenv("GEMINI_API_KEY")
if "OLLAMA_HOST" in os.environ:
monkeypatch.delenv("OLLAMA_HOST")
assert get_models() == {}
8 changes: 4 additions & 4 deletions app/tests/src/test_ingest_utils.py
@@ -43,7 +43,7 @@ def test__drop_existing_dataset(db_session, enable_factory_create):
)


-def test_process_and_ingest_sys_args_requires_four_args(caplog):
+def test_process_and_ingest_sys_args_requires_four_args(app_config, caplog):
logger = logging.getLogger(__name__)
ingest = Mock()

@@ -65,7 +65,7 @@ def test_process_and_ingest_sys_args_requires_four_args(caplog):
assert not ingest.called


-def test_process_and_ingest_sys_args_calls_ingest(caplog):
+def test_process_and_ingest_sys_args_calls_ingest(app_config, caplog):
logger = logging.getLogger(__name__)
ingest = Mock()

@@ -95,7 +95,7 @@ def test_process_and_ingest_sys_args_calls_ingest(caplog):


def test_process_and_ingest_sys_args_drops_existing_dataset(
-db_session, caplog, enable_factory_create
+db_session, app_config, caplog, enable_factory_create
):
db_session.execute(delete(Document))
logger = logging.getLogger(__name__)
@@ -144,7 +144,7 @@ def test_process_and_ingest_sys_args_drops_existing_dataset(
assert db_session.execute(select(Document).where(Document.dataset == "other dataset")).one()


-def test_process_and_ingest_sys_args_resume(db_session, caplog, enable_factory_create):
+def test_process_and_ingest_sys_args_resume(db_session, app_config, caplog, enable_factory_create):
db_session.execute(delete(Document))
logger = logging.getLogger(__name__)
ingest = Mock()
14 changes: 7 additions & 7 deletions app/tests/src/test_retrieve.py
@@ -22,8 +22,8 @@ def _create_chunks(document=None):
]


-def _format_retrieval_results(retrieval_results):
-return [chunk_with_score.chunk for chunk_with_score in retrieval_results]
+def _chunk_ids(retrieval_results):
+return [chunk_with_score.chunk.id for chunk_with_score in retrieval_results]


def test_retrieve__with_empty_filter(app_config, db_session, enable_factory_create):
@@ -33,7 +33,7 @@ def test_retrieve__with_empty_filter(app_config, db_session, enable_factory_crea
results = retrieve_with_scores(
"Very tiny words.", retrieval_k=2, retrieval_k_min_score=0.0, datasets=[]
)
-assert _format_retrieval_results(results) == [short_chunk, medium_chunk]
+assert _chunk_ids(results) == [short_chunk.id, medium_chunk.id]


def test_retrieve__with_unknown_filter(app_config, db_session, enable_factory_create):
@@ -59,7 +59,7 @@ def test_retrieve__with_dataset_filter(app_config, db_session, enable_factory_cr
retrieval_k_min_score=0.0,
datasets=["SNAP"],
)
-assert _format_retrieval_results(results) == [snap_short_chunk, snap_medium_chunk]
+assert _chunk_ids(results) == [snap_short_chunk.id, snap_medium_chunk.id]


def test_retrieve__with_other_filters(app_config, db_session, enable_factory_create):
@@ -76,7 +76,7 @@ def test_retrieve__with_other_filters(app_config, db_session, enable_factory_cre
programs=["SNAP"],
regions=["MI"],
)
-assert _format_retrieval_results(results) == [snap_short_chunk, snap_medium_chunk]
+assert _chunk_ids(results) == [snap_short_chunk.id, snap_medium_chunk.id]


def test_retrieve_with_scores(app_config, db_session, enable_factory_create):
@@ -86,7 +86,7 @@ def test_retrieve_with_scores(app_config, db_session, enable_factory_create):
results = retrieve_with_scores("Very tiny words.", retrieval_k=2, retrieval_k_min_score=0.0)

assert len(results) == 2
-assert results[0].chunk == short_chunk
+assert results[0].chunk.id == short_chunk.id
Contributor:

What's the reason for these changes? Because the db_session from the test fixture (the second parameter to test_retrieve_with_scores) is different from the db_session generated by retrieve_with_scores?

Contributor Author:

Yes. There's some identifier in the chunk instances that is specific to the db_session, so those identifiers are different and cause the assertion to fail. Otherwise the chunks are identical.

Contributor @KevinJBoyer (Jun 23, 2025):

Gotcha. I think the identifier here is the literal address in memory: under the hood, SQLAlchemy ensures that if you retrieve the same row back in the same session, it creates only a single instance of the object in memory, so you can do things like this (pseudocode):

from sqlalchemy import select

with some_session:
    # in this call SQLAlchemy creates a new instance of the Chunk class and returns it
    chunk_1 = some_session.scalars(select(Chunk).where(Chunk.id == "abc")).first()

    # here SQLAlchemy's identity map recognizes it already has an instance of Chunk for this row,
    # so it returns that same object
    chunk_2 = some_session.scalars(select(Chunk).where(Chunk.id == "abc")).first()
    assert chunk_1 is chunk_2  # these are literally the same object in memory

assert results[0].score == 0.7071067690849304
-assert results[1].chunk == medium_chunk
+assert results[1].chunk.id == medium_chunk.id
assert results[1].score == 0.25881901383399963
4 changes: 4 additions & 0 deletions infra/app/app-config/env-config/environment_variables.tf
@@ -36,6 +36,10 @@ locals {
manage_method = "manual"
secret_store_name = "/${var.app_name}-${var.environment}/OPENAI_API_KEY"
}
+GEMINI_API_KEY = {
+manage_method = "manual"
+secret_store_name = "/${var.app_name}-${var.environment}/GEMINI_API_KEY"
+}
LITERAL_API_KEY = {
manage_method = "manual"
secret_store_name = "/${var.app_name}-${var.environment}/LITERAL_API_KEY"