diff --git a/.github/workflows/promptfoo-googlesheet-evaluation.yml b/.github/workflows/promptfoo-googlesheet-evaluation.yml index 232907c3..ad81e2cd 100644 --- a/.github/workflows/promptfoo-googlesheet-evaluation.yml +++ b/.github/workflows/promptfoo-googlesheet-evaluation.yml @@ -118,9 +118,9 @@ jobs: EVAL_OUTPUT_FILE="/tmp/promptfoo-output.txt" if [ -n "$PROMPTFOO_API_KEY" ]; then - promptfoo eval --max-concurrency 1 --config "/tmp/promptfooconfig.processed.yaml" --share --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}" + promptfoo eval --config "/tmp/promptfooconfig.processed.yaml" --share --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}" else - promptfoo eval --max-concurrency 1 --config "/tmp/promptfooconfig.processed.yaml" --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}" + promptfoo eval --config "/tmp/promptfooconfig.processed.yaml" --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}" fi if [ -f "${OUTPUT_JSON_FILE}" ]; then diff --git a/app/src/app_config.py b/app/src/app_config.py index 5272a60e..4411595e 100644 --- a/app/src/app_config.py +++ b/app/src/app_config.py @@ -42,8 +42,12 @@ class AppConfig(PydanticBaseEnvConfig): # If set, used instead of LITERAL_API_KEY for API literal_api_key_for_api: str | None = None + @cached_property + def db_client(self) -> db.PostgresDBClient: + return db.PostgresDBClient() + def db_session(self) -> db.Session: - return db.PostgresDBClient().get_session() + return self.db_client.get_session() @cached_property def embedding_model(self) -> EmbeddingModel: diff --git a/app/tests/conftest.py b/app/tests/conftest.py index a616589e..1687cd82 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -108,8 +108,8 @@ def enable_factory_create(monkeypatch, db_session) -> db.Session: @pytest.fixture -def app_config(monkeypatch, db_session): - monkeypatch.setattr(AppConfig, "db_session", lambda _self: db_session) +def app_config(monkeypatch, db_client: db.DBClient): + monkeypatch.setattr(AppConfig, "db_session", lambda _self: db_client.get_session()) monkeypatch.setattr(AppConfig, "embedding_model", MockEmbeddingModel()) diff --git a/app/tests/src/db/test_pg_dump_util.py b/app/tests/src/db/test_pg_dump_util.py index 827eb608..2890e824 100644 --- a/app/tests/src/db/test_pg_dump_util.py +++ b/app/tests/src/db/test_pg_dump_util.py @@ -25,7 +25,7 @@ def populated_db(db_session): ChatMessageFactory.create_batch(4, session=user_session) -def test_backup_and_restore_db(enable_factory_create, populated_db, caplog): +def test_backup_and_restore_db(enable_factory_create, populated_db, app_config, caplog): with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname: dumpfile = f"{tmpdirname}/db.dump" pg_dump_util.backup_db(dumpfile, "local") @@ -43,7 +43,7 @@ def test_backup_and_restore_db(enable_factory_create, populated_db, caplog): assert f"DB data restored from {dumpfile!r}" in caplog.messages -def test_restore_db_without_truncating(caplog): +def test_restore_db_without_truncating(app_config, caplog): with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname: dumpfile = f"{tmpdirname}/db.dump" with open(dumpfile, "wb"): # Create an empty file @@ -54,7 +54,7 @@ def test_restore_db_without_truncating(caplog): ) -def test_backup_db__file_exists(caplog): +def test_backup_db__file_exists(app_config, caplog): with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname: dumpfile = f"{tmpdirname}/db.dump" with open(dumpfile, "wb"): # Create an empty file @@ -68,7 +68,7 @@ def test_backup_db__file_exists(caplog): assert f"DB data dumped to '{tmpdirname}/db.dump'" not in caplog.messages -def test_backup_db__dump_failure(caplog, monkeypatch): +def test_backup_db__dump_failure(app_config, caplog, monkeypatch): monkeypatch.setattr(pg_dump_util, "_pg_dump", lambda *args: False) with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname: dumpfile = f"{tmpdirname}/db.dump" @@ -77,7 +77,7 @@ def test_backup_db__dump_failure(caplog, monkeypatch): assert f"Failed to dump DB data to '{tmpdirname}/db.dump'" in caplog.messages -def test_backup_db__truncate_failure(caplog, monkeypatch): +def test_backup_db__truncate_failure(app_config, caplog, monkeypatch): monkeypatch.setattr(pg_dump_util, "_truncate_db_tables", lambda *args: False) with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname: dumpfile = f"{tmpdirname}/db.dump" @@ -95,7 +95,9 @@ def mock_s3_dev_bucket(mock_s3): yield bucket -def test_backup_db_for_dev(enable_factory_create, populated_db, caplog, mock_s3_dev_bucket): +def test_backup_db_for_dev( + enable_factory_create, populated_db, app_config, caplog, mock_s3_dev_bucket +): with caplog.at_level(logging.INFO): dumpfile = "dev_db.dump" pg_dump_util.backup_db(dumpfile, "dev") @@ -109,7 +111,7 @@ def test_backup_db_for_dev(enable_factory_create, populated_db, caplog, mock_s3_ ) -def test_restore_db_failure(caplog): +def test_restore_db_failure(app_config, caplog): with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname: dumpfile = f"{tmpdirname}/db.dump" diff --git a/app/tests/src/evaluation/cli/test_cli_generate.py b/app/tests/src/evaluation/cli/test_cli_generate.py index 729869e6..150f9253 100644 --- a/app/tests/src/evaluation/cli/test_cli_generate.py +++ b/app/tests/src/evaluation/cli/test_cli_generate.py @@ -91,7 +91,7 @@ def test_invalid_arguments(): @pytest.mark.integration -def test_main_integration(temp_output_dir): +def test_main_integration(temp_output_dir, app_config): """Integration test with minimal test data.""" with mock.patch( "sys.argv", @@ -122,7 +122,7 @@ def test_main_integration(temp_output_dir): assert mock_qa_pairs_path.exists() -def test_error_handling_no_documents(): +def test_error_handling_no_documents(app_config): """Test handling of 'No documents found' error.""" with mock.patch("sys.argv", ["generate.py"]): with mock.patch("src.evaluation.qa_generation.runner.run_generation") as mock_run: diff --git a/app/tests/src/evaluation/qa_generation/test_qa_generation_runner.py b/app/tests/src/evaluation/qa_generation/test_qa_generation_runner.py index 07678cc9..a08a2d14 100644 --- a/app/tests/src/evaluation/qa_generation/test_qa_generation_runner.py +++ b/app/tests/src/evaluation/qa_generation/test_qa_generation_runner.py @@ -110,7 +110,7 @@ def test_qa_storage(qa_pairs, tmp_path): def test_run_generation_basic( - mock_completion_response, tmp_path, enable_factory_create, db_session + mock_completion_response, tmp_path, enable_factory_create, db_session, app_config ): """Test basic run_generation functionality.""" output_dir = tmp_path / "qa_output" @@ -147,7 +147,7 @@ def test_run_generation_basic( def test_run_generation_with_dataset_filter( - mock_completion_response, tmp_path, enable_factory_create, db_session + mock_completion_response, tmp_path, enable_factory_create, db_session, app_config ): """Test run_generation with dataset filter.""" output_dir = tmp_path / "qa_output" @@ -261,7 +261,7 @@ def test_save_qa_pairs(qa_pairs, tmp_path): def test_run_generation_with_version( - mock_completion_response, tmp_path, enable_factory_create, db_session + mock_completion_response, tmp_path, enable_factory_create, db_session, app_config ): """Test run_generation with version parameter.""" output_dir = tmp_path / "qa_output" @@ -301,7 +301,9 @@ def test_run_generation_with_version( assert pair.version.llm_model == version.llm_model -def test_run_generation_invalid_sample_fraction(tmp_path, enable_factory_create, db_session): +def test_run_generation_invalid_sample_fraction( + tmp_path, enable_factory_create, db_session, app_config +): """Test run_generation with invalid sample fraction values.""" output_dir = tmp_path / "qa_output" llm_model = "gpt-4o-mini" diff --git a/app/tests/src/test_chat_api.py b/app/tests/src/test_chat_api.py index 1a7effaa..d7aef685 100644 --- a/app/tests/src/test_chat_api.py +++ b/app/tests/src/test_chat_api.py @@ -38,7 +38,7 @@ def no_literalai_data_layer(monkeypatch): @pytest.fixture -def async_client(no_literalai_data_layer, db_session): +def async_client(no_literalai_data_layer, db_session, app_config): """ The typical FastAPI TestClient creates its own event loop to handle requests, which led to issues when testing code that relies on asynchronous operations @@ -108,7 +108,7 @@ async def test_api_engines(async_client, db_session): @pytest.mark.asyncio -async def test_api_engines__dbsession_contextvar(async_client, monkeypatch, db_session): +async def test_api_engines__dbsession_contextvar(async_client, monkeypatch, db_session, app_config): event = asyncio.Event() db_sessions = [] orig_init_chat_session = chat_api._init_chat_session diff --git a/app/tests/src/test_ingest_utils.py b/app/tests/src/test_ingest_utils.py index a24f8a60..22d7a1db 100644 --- a/app/tests/src/test_ingest_utils.py +++ b/app/tests/src/test_ingest_utils.py @@ -43,7 +43,7 @@ def test__drop_existing_dataset(db_session, enable_factory_create): ) -def test_process_and_ingest_sys_args_requires_four_args(caplog): +def test_process_and_ingest_sys_args_requires_four_args(app_config, caplog): logger = logging.getLogger(__name__) ingest = Mock() @@ -65,7 +65,7 @@ def test_process_and_ingest_sys_args_requires_four_args(caplog): assert not ingest.called -def test_process_and_ingest_sys_args_calls_ingest(caplog): +def test_process_and_ingest_sys_args_calls_ingest(app_config, caplog): logger = logging.getLogger(__name__) ingest = Mock() @@ -95,7 +95,7 @@ def test_process_and_ingest_sys_args_calls_ingest(caplog): def test_process_and_ingest_sys_args_drops_existing_dataset( - db_session, caplog, enable_factory_create + db_session, app_config, caplog, enable_factory_create ): db_session.execute(delete(Document)) logger = logging.getLogger(__name__) @@ -144,7 +144,7 @@ def test_process_and_ingest_sys_args_drops_existing_dataset( assert db_session.execute(select(Document).where(Document.dataset == "other dataset")).one() -def test_process_and_ingest_sys_args_resume(db_session, caplog, enable_factory_create): +def test_process_and_ingest_sys_args_resume(db_session, app_config, caplog, enable_factory_create): db_session.execute(delete(Document)) logger = logging.getLogger(__name__) ingest = Mock() diff --git a/app/tests/src/test_retrieve.py b/app/tests/src/test_retrieve.py index 93720fde..6c84d9cb 100644 --- a/app/tests/src/test_retrieve.py +++ b/app/tests/src/test_retrieve.py @@ -22,8 +22,8 @@ def _create_chunks(document=None): ] -def _format_retrieval_results(retrieval_results): - return [chunk_with_score.chunk for chunk_with_score in retrieval_results] +def _chunk_ids(retrieval_results): + return [chunk_with_score.chunk.id for chunk_with_score in retrieval_results] def test_retrieve__with_empty_filter(app_config, db_session, enable_factory_create): @@ -33,7 +33,7 @@ def test_retrieve__with_empty_filter(app_config, db_session, enable_factory_crea results = retrieve_with_scores( "Very tiny words.", retrieval_k=2, retrieval_k_min_score=0.0, datasets=[] ) - assert _format_retrieval_results(results) == [short_chunk, medium_chunk] + assert _chunk_ids(results) == [short_chunk.id, medium_chunk.id] def test_retrieve__with_unknown_filter(app_config, db_session, enable_factory_create): @@ -59,7 +59,7 @@ def test_retrieve__with_dataset_filter(app_config, db_session, enable_factory_cr retrieval_k_min_score=0.0, datasets=["SNAP"], ) - assert _format_retrieval_results(results) == [snap_short_chunk, snap_medium_chunk] + assert _chunk_ids(results) == [snap_short_chunk.id, snap_medium_chunk.id] def test_retrieve__with_other_filters(app_config, db_session, enable_factory_create): @@ -76,7 +76,7 @@ def test_retrieve__with_other_filters(app_config, db_session, enable_factory_cre programs=["SNAP"], regions=["MI"], ) - assert _format_retrieval_results(results) == [snap_short_chunk, snap_medium_chunk] + assert _chunk_ids(results) == [snap_short_chunk.id, snap_medium_chunk.id] def test_retrieve_with_scores(app_config, db_session, enable_factory_create): @@ -86,7 +86,7 @@ def test_retrieve_with_scores(app_config, db_session, enable_factory_create): results = retrieve_with_scores("Very tiny words.", retrieval_k=2, retrieval_k_min_score=0.0) assert len(results) == 2 - assert results[0].chunk == short_chunk + assert results[0].chunk.id == short_chunk.id assert results[0].score == 0.7071067690849304 - assert results[1].chunk == medium_chunk + assert results[1].chunk.id == medium_chunk.id assert results[1].score == 0.25881901383399963