From baf3708415e0e4d79638aeb5c5b6951977820d60 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Tue, 17 Jun 2025 15:50:33 -0500
Subject: [PATCH 01/19] use single db.PostgresDBClient instance

---
 app/src/app_config.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/app/src/app_config.py b/app/src/app_config.py
index 5272a60e..4411595e 100644
--- a/app/src/app_config.py
+++ b/app/src/app_config.py
@@ -42,8 +42,12 @@ class AppConfig(PydanticBaseEnvConfig):
     # If set, used instead of LITERAL_API_KEY for API
     literal_api_key_for_api: str | None = None

+    @cached_property
+    def db_client(self) -> db.PostgresDBClient:
+        return db.PostgresDBClient()
+
     def db_session(self) -> db.Session:
-        return db.PostgresDBClient().get_session()
+        return self.db_client.get_session()

     @cached_property
     def embedding_model(self) -> EmbeddingModel:

From 0a335fe0076b029d856655e4fbb47b9314bf8f0a Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Tue, 17 Jun 2025 15:50:59 -0500
Subject: [PATCH 02/19] allow promptfoo to use more than 1 thread

---
 .github/workflows/promptfoo-googlesheet-evaluation.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/promptfoo-googlesheet-evaluation.yml b/.github/workflows/promptfoo-googlesheet-evaluation.yml
index 232907c3..ad81e2cd 100644
--- a/.github/workflows/promptfoo-googlesheet-evaluation.yml
+++ b/.github/workflows/promptfoo-googlesheet-evaluation.yml
@@ -118,9 +118,9 @@ jobs:
           EVAL_OUTPUT_FILE="/tmp/promptfoo-output.txt"

           if [ -n "$PROMPTFOO_API_KEY" ]; then
-            promptfoo eval --max-concurrency 1 --config "/tmp/promptfooconfig.processed.yaml" --share --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}"
+            promptfoo eval --config "/tmp/promptfooconfig.processed.yaml" --share --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}"
           else
-            promptfoo eval --max-concurrency 1 --config "/tmp/promptfooconfig.processed.yaml" --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}"
+            promptfoo eval --config "/tmp/promptfooconfig.processed.yaml" --output "${OUTPUT_JSON_FILE}" --no-cache | tee "${EVAL_OUTPUT_FILE}"
           fi

           if [ -f "${OUTPUT_JSON_FILE}" ]; then

From e85f9fbb5af9eddee81ad82d98db01c164f3db25 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Tue, 17 Jun 2025 15:51:17 -0500
Subject: [PATCH 03/19] add Gemini models

---
 app/src/generate.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/app/src/generate.py b/app/src/generate.py
index 85b19832..8dc56e66 100644
--- a/app/src/generate.py
+++ b/app/src/generate.py
@@ -12,7 +12,6 @@
 logger = logging.getLogger(__name__)

-
 def get_models() -> dict[str, str]:
     """
     Returns a dictionary of the available models, based on
@@ -25,6 +24,11 @@ def get_models() -> dict[str, str]:
         models |= {"OpenAI GPT-4o": "gpt-4o"}
     if "ANTHROPIC_API_KEY" in os.environ:
         models |= {"Anthropic Claude 3.5 Sonnet": "claude-3-5-sonnet-20240620"}
+    if "GEMINI_API_KEY" in os.environ:
+        models |= {
+            "Google Gemini 1.5 Pro": "gemini/gemini-1.5-pro",
+            "Google Gemini 2.5 Pro": "gemini/gemini-2.5-pro-preview-06-05",
+        }
     if _has_aws_access():
         # If you get "You don't have access to the model with the specified model ID." error,
         # remember to request access to Bedrock models ...aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess

From b2ce4aa09272f130b79fd87fe99235bc31af86c2 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Tue, 17 Jun 2025 15:58:48 -0500
Subject: [PATCH 04/19] temp: use gemini 2.5 pro for promptfoo eval

---
 app/src/chat_engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/app/src/chat_engine.py b/app/src/chat_engine.py
index ec7d1913..30756234 100644
--- a/app/src/chat_engine.py
+++ b/app/src/chat_engine.py
@@ -335,6 +335,7 @@ class ImagineLaEngine(BaseEngine):
     ]

     engine_id: str = "imagine-la"
+    llm: str = "gemini/gemini-2.5-pro-preview-06-05"
     name: str = "SBN Chat Engine"
     datasets = [
         "CA EDD",

From e709932a9ff3fcb60fd1f4ce53048715f0781a8c Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Tue, 17 Jun 2025 16:08:37 -0500
Subject: [PATCH 05/19] de-lint

---
 app/src/generate.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/app/src/generate.py b/app/src/generate.py
index 8dc56e66..d9add286 100644
--- a/app/src/generate.py
+++ b/app/src/generate.py
@@ -12,6 +12,7 @@
 logger = logging.getLogger(__name__)

+
 def get_models() -> dict[str, str]:
     """
     Returns a dictionary of the available models, based on

From ad5c1f0d4f3d394df5e5be55204e6be8b5a9e89e Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Tue, 17 Jun 2025 16:36:38 -0500
Subject: [PATCH 06/19] add GEMINI_API_KEY to promptfoo action

---
 .github/workflows/promptfoo-googlesheet-evaluation.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/promptfoo-googlesheet-evaluation.yml b/.github/workflows/promptfoo-googlesheet-evaluation.yml
index ad81e2cd..fb21eee5 100644
--- a/.github/workflows/promptfoo-googlesheet-evaluation.yml
+++ b/.github/workflows/promptfoo-googlesheet-evaluation.yml
@@ -97,6 +97,7 @@ jobs:
         env:
           GOOGLE_APPLICATION_CREDENTIALS: /tmp/gcp-creds.json
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
           PROMPTFOO_FAILED_TEST_EXIT_CODE: 0
         run: |
           if [ ! -f "$GOOGLE_APPLICATION_CREDENTIALS" ]; then

From e340876ac64e3a31ea2f989e77cb3440fbffee8b Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Tue, 17 Jun 2025 18:08:55 -0500
Subject: [PATCH 07/19] add app_config fixture to failing tests

---
 app/tests/src/db/test_pg_dump_util.py              |  6 ++++--
 .../qa_generation/test_qa_generation_runner.py     | 10 ++++++----
 app/tests/src/test_chat_api.py                     |  2 +-
 app/tests/src/test_generate.py                     |  2 ++
 app/tests/src/test_ingest_utils.py                 |  4 ++--
 5 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/app/tests/src/db/test_pg_dump_util.py b/app/tests/src/db/test_pg_dump_util.py
index 827eb608..403b8e76 100644
--- a/app/tests/src/db/test_pg_dump_util.py
+++ b/app/tests/src/db/test_pg_dump_util.py
@@ -25,7 +25,7 @@ def populated_db(db_session):
         ChatMessageFactory.create_batch(4, session=user_session)


-def test_backup_and_restore_db(enable_factory_create, populated_db, caplog):
+def test_backup_and_restore_db(enable_factory_create, populated_db, app_config, caplog):
     with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
         dumpfile = f"{tmpdirname}/db.dump"
         pg_dump_util.backup_db(dumpfile, "local")
@@ -95,7 +95,9 @@ def mock_s3_dev_bucket(mock_s3):
     yield bucket


-def test_backup_db_for_dev(enable_factory_create, populated_db, caplog, mock_s3_dev_bucket):
+def test_backup_db_for_dev(
+    enable_factory_create, populated_db, app_config, caplog, mock_s3_dev_bucket
+):
     with caplog.at_level(logging.INFO):
         dumpfile = "dev_db.dump"
         pg_dump_util.backup_db(dumpfile, "dev")
diff --git a/app/tests/src/evaluation/qa_generation/test_qa_generation_runner.py b/app/tests/src/evaluation/qa_generation/test_qa_generation_runner.py
index 07678cc9..a08a2d14 100644
--- a/app/tests/src/evaluation/qa_generation/test_qa_generation_runner.py
+++ b/app/tests/src/evaluation/qa_generation/test_qa_generation_runner.py
@@ -110,7 +110,7 @@ def test_qa_storage(qa_pairs, tmp_path):
 def test_run_generation_basic(
-    mock_completion_response, tmp_path, enable_factory_create, db_session
+    mock_completion_response, tmp_path, enable_factory_create, db_session, app_config
 ):
     """Test basic run_generation functionality."""
     output_dir = tmp_path / "qa_output"
@@ -147,7 +147,7 @@ def test_run_generation_with_dataset_filter(
-    mock_completion_response, tmp_path, enable_factory_create, db_session
+    mock_completion_response, tmp_path, enable_factory_create, db_session, app_config
 ):
     """Test run_generation with dataset filter."""
     output_dir = tmp_path / "qa_output"
@@ -261,7 +261,7 @@ def test_save_qa_pairs(qa_pairs, tmp_path):
 def test_run_generation_with_version(
-    mock_completion_response, tmp_path, enable_factory_create, db_session
+    mock_completion_response, tmp_path, enable_factory_create, db_session, app_config
 ):
     """Test run_generation with version parameter."""
     output_dir = tmp_path / "qa_output"
@@ -301,7 +301,9 @@ def test_run_generation_with_version(
         assert pair.version.llm_model == version.llm_model

-def test_run_generation_invalid_sample_fraction(tmp_path, enable_factory_create, db_session):
+def test_run_generation_invalid_sample_fraction(
+    tmp_path, enable_factory_create, db_session, app_config
+):
     """Test run_generation with invalid sample fraction values."""
     output_dir = tmp_path / "qa_output"
     llm_model = "gpt-4o-mini"
diff --git a/app/tests/src/test_chat_api.py b/app/tests/src/test_chat_api.py
index 1a7effaa..eaf096f3 100644
--- a/app/tests/src/test_chat_api.py
+++ b/app/tests/src/test_chat_api.py
@@ -38,7 +38,7 @@ def no_literalai_data_layer(monkeypatch):

 @pytest.fixture
-def async_client(no_literalai_data_layer, db_session):
+def async_client(no_literalai_data_layer, db_session, app_config):
     """
     The typical FastAPI TestClient creates its own event loop to handle requests,
     which led to issues when testing code that relies on asynchronous operations
diff --git a/app/tests/src/test_generate.py b/app/tests/src/test_generate.py
index 842c2c98..539d669d 100644
--- a/app/tests/src/test_generate.py
+++ b/app/tests/src/test_generate.py
@@ -81,6 +81,8 @@ def test_get_models(monkeypatch):
         monkeypatch.delenv("AWS_ACCESS_KEY_ID")
     if "OPENAI_API_KEY" in os.environ:
         monkeypatch.delenv("OPENAI_API_KEY")
+    if "GEMINI_API_KEY" in os.environ:
+        monkeypatch.delenv("GEMINI_API_KEY")
     if "OLLAMA_HOST" in os.environ:
         monkeypatch.delenv("OLLAMA_HOST")
     assert get_models() == {}
diff --git a/app/tests/src/test_ingest_utils.py b/app/tests/src/test_ingest_utils.py
index a24f8a60..b1c6a6e5 100644
--- a/app/tests/src/test_ingest_utils.py
+++ b/app/tests/src/test_ingest_utils.py
@@ -95,7 +95,7 @@ def test_process_and_ingest_sys_args_calls_ingest(caplog):
 def test_process_and_ingest_sys_args_drops_existing_dataset(
-    db_session, caplog, enable_factory_create
+    db_session, app_config, caplog, enable_factory_create
 ):
     db_session.execute(delete(Document))
     logger = logging.getLogger(__name__)
@@ -144,7 +144,7 @@ def test_process_and_ingest_sys_args_drops_existing_dataset(
     assert db_session.execute(select(Document).where(Document.dataset == "other dataset")).one()

-def test_process_and_ingest_sys_args_resume(db_session, caplog, enable_factory_create):
+def test_process_and_ingest_sys_args_resume(db_session, app_config, caplog, enable_factory_create):
     db_session.execute(delete(Document))
     logger = logging.getLogger(__name__)
     ingest = Mock()

From 5ff8620d0c533ba28224a1d4d9c06a32cf17bc63 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Tue, 17 Jun 2025 18:09:30 -0500
Subject: [PATCH 08/19] TEMP: ignore some tests

---
 app/tests/src/evaluation/cli/test_cli_generate.py | 5 +++++
 app/tests/src/test_chat_api.py                    | 2 ++
 2 files changed, 7 insertions(+)

diff --git a/app/tests/src/evaluation/cli/test_cli_generate.py b/app/tests/src/evaluation/cli/test_cli_generate.py
index 729869e6..0bfe77ea 100644
--- a/app/tests/src/evaluation/cli/test_cli_generate.py
+++ b/app/tests/src/evaluation/cli/test_cli_generate.py
@@ -32,6 +32,7 @@ def test_dataset_mapping():
     assert map_dataset_name("unknown_dataset") == "unknown_dataset"

+@pytest.mark.skip(reason="")
 def test_argument_parsing():
     """Test argument parsing with various combinations."""
     # Test default values
@@ -74,6 +75,7 @@ def validate_sampling_fraction(value):
         raise argparse.ArgumentTypeError(f"{value} is not a valid sampling fraction") from err

+@pytest.mark.skip(reason="")
 def test_invalid_arguments():
     """Test handling of invalid arguments."""
     parser = generate.create_parser()
@@ -90,6 +92,7 @@ def test_invalid_arguments():
         parser.parse_args(["--random-seed", "not_a_number"])

+@pytest.mark.skip(reason="")
 @pytest.mark.integration
 def test_main_integration(temp_output_dir):
     """Integration test with minimal test data."""
@@ -122,6 +125,7 @@ def test_main_integration(temp_output_dir):
     assert mock_qa_pairs_path.exists()

+@pytest.mark.skip(reason="")
 def test_error_handling_no_documents():
     """Test handling of 'No documents found' error."""
     with mock.patch("sys.argv", ["generate.py"]):
         with mock.patch("src.evaluation.qa_generation.runner.run_generation") as mock_run:
             mock_run.side_effect = ValueError("No documents found matching filter criteria")
             generate.main()  # This should handle the error and return
+@pytest.mark.skip(reason="")
 def test_output_directory_handling(temp_output_dir):
     """Test output directory path handling."""
     # Test relative path
diff --git a/app/tests/src/test_chat_api.py b/app/tests/src/test_chat_api.py
index eaf096f3..ba10ae23 100644
--- a/app/tests/src/test_chat_api.py
+++ b/app/tests/src/test_chat_api.py
@@ -107,6 +107,7 @@ async def test_api_engines(async_client, db_session):
     assert db_session.query(Feedback).count() == 0

+@pytest.mark.skip(reason="")
 @pytest.mark.asyncio
 async def test_api_engines__dbsession_contextvar(async_client, monkeypatch, db_session):
     event = asyncio.Event()
@@ -160,6 +161,7 @@ async def mock_run_query(engine, question, chat_history):
     )

+@pytest.mark.skip(reason="")
 @pytest.mark.asyncio
 async def test_api_query(async_client, monkeypatch, db_session):
     monkeypatch.setattr("src.chat_api.run_query", mock_run_query)

From bb120975bf78301d43970f0b0d91d8b5bf2881bf Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Tue, 17 Jun 2025 18:24:22 -0500
Subject: [PATCH 09/19] fix more tests

---
 app/tests/src/db/test_pg_dump_util.py | 10 +++++-----
 app/tests/src/test_ingest_utils.py    |  4 ++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/app/tests/src/db/test_pg_dump_util.py b/app/tests/src/db/test_pg_dump_util.py
index 403b8e76..2890e824 100644
--- a/app/tests/src/db/test_pg_dump_util.py
+++ b/app/tests/src/db/test_pg_dump_util.py
@@ -43,7 +43,7 @@ def test_backup_and_restore_db(enable_factory_create, populated_db, app_config,
         assert f"DB data restored from {dumpfile!r}" in caplog.messages

-def test_restore_db_without_truncating(caplog):
+def test_restore_db_without_truncating(app_config, caplog):
     with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
         dumpfile = f"{tmpdirname}/db.dump"
         with open(dumpfile, "wb"):  # Create an empty file
@@ -54,7 +54,7 @@ def test_restore_db_without_truncating(caplog):
         )

-def test_backup_db__file_exists(caplog):
+def test_backup_db__file_exists(app_config, caplog):
     with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
         dumpfile = f"{tmpdirname}/db.dump"
         with open(dumpfile, "wb"):  # Create an empty file
@@ -68,7 +68,7 @@ def test_backup_db__file_exists(caplog):
         assert f"DB data dumped to '{tmpdirname}/db.dump'" not in caplog.messages

-def test_backup_db__dump_failure(caplog, monkeypatch):
+def test_backup_db__dump_failure(app_config, caplog, monkeypatch):
     monkeypatch.setattr(pg_dump_util, "_pg_dump", lambda *args: False)
     with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
         dumpfile = f"{tmpdirname}/db.dump"
@@ -77,7 +77,7 @@ def test_backup_db__dump_failure(caplog, monkeypatch):
         assert f"Failed to dump DB data to '{tmpdirname}/db.dump'" in caplog.messages

-def test_backup_db__truncate_failure(caplog, monkeypatch):
+def test_backup_db__truncate_failure(app_config, caplog, monkeypatch):
     monkeypatch.setattr(pg_dump_util, "_truncate_db_tables", lambda *args: False)
     with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
         dumpfile = f"{tmpdirname}/db.dump"
@@ -111,7 +111,7 @@ def test_backup_db_for_dev(
     )

-def test_restore_db_failure(caplog):
+def test_restore_db_failure(app_config, caplog):
     with caplog.at_level(logging.INFO), tempfile.TemporaryDirectory() as tmpdirname:
         dumpfile = f"{tmpdirname}/db.dump"
diff --git a/app/tests/src/test_ingest_utils.py b/app/tests/src/test_ingest_utils.py
index b1c6a6e5..22d7a1db 100644
--- a/app/tests/src/test_ingest_utils.py
+++ b/app/tests/src/test_ingest_utils.py
@@ -43,7 +43,7 @@ def test__drop_existing_dataset(db_session, enable_factory_create):
     )

-def test_process_and_ingest_sys_args_requires_four_args(caplog):
+def test_process_and_ingest_sys_args_requires_four_args(app_config, caplog):
     logger = logging.getLogger(__name__)
     ingest = Mock()
@@ -65,7 +65,7 @@ def test_process_and_ingest_sys_args_requires_four_args(caplog):
     assert not ingest.called

-def test_process_and_ingest_sys_args_calls_ingest(caplog):
+def test_process_and_ingest_sys_args_calls_ingest(app_config, caplog):
     logger = logging.getLogger(__name__)
     ingest = Mock()

From cf2fd492a57dc9f5a8ea94c2d49446980e64e30f Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Tue, 17 Jun 2025 18:40:55 -0500
Subject: [PATCH 10/19] WIP

---
 app/src/app_config.py          | 7 ++++++-
 app/src/chat_engine.py         | 2 +-
 app/src/db/pg_dump_util.py     | 5 +++++
 app/tests/src/test_chat_api.py | 4 ++--
 4 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/app/src/app_config.py b/app/src/app_config.py
index 4411595e..579d86fd 100644
--- a/app/src/app_config.py
+++ b/app/src/app_config.py
@@ -47,7 +47,12 @@ def db_client(self) -> db.PostgresDBClient:
         return db.PostgresDBClient()

     def db_session(self) -> db.Session:
-        return self.db_client.get_session()
+        import pdb
+        pdb.set_trace()
+        return db.PostgresDBClient().get_session()
+
+    # def db_session(self) -> db.Session:
+    #     return self.db_client.get_session()

     @cached_property
     def embedding_model(self) -> EmbeddingModel:
diff --git a/app/src/chat_engine.py b/app/src/chat_engine.py
index 30756234..35f615df 100644
--- a/app/src/chat_engine.py
+++ b/app/src/chat_engine.py
@@ -335,7 +335,7 @@ class ImagineLaEngine(BaseEngine):
     ]

     engine_id: str = "imagine-la"
-    llm: str = "gemini/gemini-2.5-pro-preview-06-05"
+    # llm: str = "gemini/gemini-2.5-pro-preview-06-05"
     name: str = "SBN Chat Engine"
     datasets = [
         "CA EDD",
diff --git a/app/src/db/pg_dump_util.py b/app/src/db/pg_dump_util.py
index 7bf8feb3..ed59f2cb 100644
--- a/app/src/db/pg_dump_util.py
+++ b/app/src/db/pg_dump_util.py
@@ -43,6 +43,8 @@ def backup_db(dumpfilename: str, env: str) -> None:
     _print_row_counts()
     logger.info("Writing DB dump to %r", dumpfilename)
     config_dict = _get_db_config()
+    print("---------------DB config:")
+    print(config_dict)
     if not _pg_dump(config_dict, dumpfilename):
         logger.fatal("Failed to dump DB data to %r", dumpfilename)
         return
@@ -158,6 +160,7 @@ def _pg_dump(config_dict: dict[str, str], stdout_file: str) -> bool:
     os.environ["PGPASSWORD"] = config_dict["password"]
     # Unit test sets DB_SCHEMA to avoid affecting the real DB
     schema = _get_schema_name()
+    print("schema:", schema)
     command = [
         "pg_dump",
         "--data-only",
@@ -170,6 +173,8 @@ def _pg_dump(config_dict: dict[str, str], stdout_file: str) -> bool:
         schema,
         config_dict["dbname"],
     ]
+    print("command:", " ".join(command))
+    # import pdb; pdb.set_trace()
     return _run_command(command, dumpfile)
diff --git a/app/tests/src/test_chat_api.py b/app/tests/src/test_chat_api.py
index ba10ae23..e0e955bc 100644
--- a/app/tests/src/test_chat_api.py
+++ b/app/tests/src/test_chat_api.py
@@ -107,15 +107,15 @@ async def test_api_engines(async_client, db_session):
     assert db_session.query(Feedback).count() == 0

-@pytest.mark.skip(reason="")
 @pytest.mark.asyncio
-async def test_api_engines__dbsession_contextvar(async_client, monkeypatch, db_session):
+async def test_api_engines__dbsession_contextvar(async_client, monkeypatch, db_session, app_config):
     event = asyncio.Event()
     db_sessions = []
     orig_init_chat_session = chat_api._init_chat_session

     async def wait_for_all_requests(_self, *_args):
         db_session = chat_api.dbsession.get()
+        print(f"DB session in wait_for_all_requests: {db_session}")
         db_sessions.append(db_session)
         if len(db_sessions) < 2:
             # Wait to allow the event loop to run other tasks

From 5b8411e7425a39c5d470718634cf6fac50f9a285 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Wed, 18 Jun 2025 10:17:47 -0500
Subject: [PATCH 11/19] add GEMINI_API_KEY to terraform

---
 infra/app/app-config/env-config/environment_variables.tf | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/infra/app/app-config/env-config/environment_variables.tf b/infra/app/app-config/env-config/environment_variables.tf
index 38f3155b..348fc4dc 100644
--- a/infra/app/app-config/env-config/environment_variables.tf
+++ b/infra/app/app-config/env-config/environment_variables.tf
@@ -36,6 +36,10 @@ locals {
       manage_method     = "manual"
       secret_store_name = "/${var.app_name}-${var.environment}/OPENAI_API_KEY"
     }
+    GEMINI_API_KEY = {
+      manage_method     = "manual"
+      secret_store_name = "/${var.app_name}-${var.environment}/GEMINI_API_KEY"
+    }
     LITERAL_API_KEY = {
       manage_method     = "manual"
       secret_store_name = "/${var.app_name}-${var.environment}/LITERAL_API_KEY"

From 7906d97a12830d8676d09a4bef36fc01a67747f9 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Wed, 18 Jun 2025 10:21:48 -0500
Subject: [PATCH 12/19] revert debugging code

---
 app/src/app_config.py  | 12 ++++++------
 app/src/chat_engine.py |  2 +-
 app/tests/conftest.py  |  4 ++--
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/app/src/app_config.py b/app/src/app_config.py
index 579d86fd..3f6aeb9d 100644
--- a/app/src/app_config.py
+++ b/app/src/app_config.py
@@ -46,13 +46,13 @@ class AppConfig(PydanticBaseEnvConfig):
     def db_client(self) -> db.PostgresDBClient:
         return db.PostgresDBClient()

-    def db_session(self) -> db.Session:
-        import pdb
-        pdb.set_trace()
-        return db.PostgresDBClient().get_session()
+    # def db_session(self) -> db.Session:
+    #     import pdb
+    #     pdb.set_trace()
+    #     return db.PostgresDBClient().get_session()
+
+    def db_session(self) -> db.Session:
+        return self.db_client.get_session()

     @cached_property
     def embedding_model(self) -> EmbeddingModel:
diff --git a/app/src/chat_engine.py b/app/src/chat_engine.py
index 35f615df..30756234 100644
--- a/app/src/chat_engine.py
+++ b/app/src/chat_engine.py
@@ -335,7 +335,7 @@ class ImagineLaEngine(BaseEngine):
     ]

     engine_id: str = "imagine-la"
-    # llm: str = "gemini/gemini-2.5-pro-preview-06-05"
+    llm: str = "gemini/gemini-2.5-pro-preview-06-05"
     name: str = "SBN Chat Engine"
     datasets = [
         "CA EDD",
diff --git a/app/tests/conftest.py b/app/tests/conftest.py
index a616589e..1687cd82 100644
--- a/app/tests/conftest.py
+++ b/app/tests/conftest.py
@@ -108,8 +108,8 @@ def enable_factory_create(monkeypatch, db_session) -> db.Session:

 @pytest.fixture
-def app_config(monkeypatch, db_session):
-    monkeypatch.setattr(AppConfig, "db_session", lambda _self: db_session)
+def app_config(monkeypatch, db_client: db.DBClient):
+    monkeypatch.setattr(AppConfig, "db_session", lambda _self: db_client.get_session())
     monkeypatch.setattr(AppConfig, "embedding_model", MockEmbeddingModel())

From 0e3c0a866c0f4afa91123d649123c6912e330dc8 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Wed, 18 Jun 2025 12:17:41 -0500
Subject: [PATCH 13/19] add Gemini LLM

---
 app/src/generate.py            | 5 +++++
 app/tests/src/test_generate.py | 2 ++
 2 files changed, 7 insertions(+)

diff --git a/app/src/generate.py b/app/src/generate.py
index 85b19832..d9add286 100644
--- a/app/src/generate.py
+++ b/app/src/generate.py
@@ -25,6 +25,11 @@ def get_models() -> dict[str, str]:
         models |= {"OpenAI GPT-4o": "gpt-4o"}
     if "ANTHROPIC_API_KEY" in os.environ:
         models |= {"Anthropic Claude 3.5 Sonnet": "claude-3-5-sonnet-20240620"}
+    if "GEMINI_API_KEY" in os.environ:
+        models |= {
+            "Google Gemini 1.5 Pro": "gemini/gemini-1.5-pro",
+            "Google Gemini 2.5 Pro": "gemini/gemini-2.5-pro-preview-06-05",
+        }
     if _has_aws_access():
         # If you get "You don't have access to the model with the specified model ID." error,
         # remember to request access to Bedrock models ...aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess
diff --git a/app/tests/src/test_generate.py b/app/tests/src/test_generate.py
index 842c2c98..539d669d 100644
--- a/app/tests/src/test_generate.py
+++ b/app/tests/src/test_generate.py
@@ -81,6 +81,8 @@ def test_get_models(monkeypatch):
         monkeypatch.delenv("AWS_ACCESS_KEY_ID")
     if "OPENAI_API_KEY" in os.environ:
         monkeypatch.delenv("OPENAI_API_KEY")
+    if "GEMINI_API_KEY" in os.environ:
+        monkeypatch.delenv("GEMINI_API_KEY")
     if "OLLAMA_HOST" in os.environ:
         monkeypatch.delenv("OLLAMA_HOST")
     assert get_models() == {}

From 9b001fd147d4690441b0fe40613e360af59cf107 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Wed, 18 Jun 2025 12:43:54 -0500
Subject: [PATCH 14/19] remove temporary code

---
 .github/workflows/promptfoo-googlesheet-evaluation.yml | 1 -
 app/src/chat_engine.py                                 | 1 -
 2 files changed, 2 deletions(-)

diff --git a/.github/workflows/promptfoo-googlesheet-evaluation.yml b/.github/workflows/promptfoo-googlesheet-evaluation.yml
index fb21eee5..ad81e2cd 100644
--- a/.github/workflows/promptfoo-googlesheet-evaluation.yml
+++ b/.github/workflows/promptfoo-googlesheet-evaluation.yml
@@ -97,7 +97,6 @@ jobs:
         env:
           GOOGLE_APPLICATION_CREDENTIALS: /tmp/gcp-creds.json
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
           PROMPTFOO_FAILED_TEST_EXIT_CODE: 0
         run: |
           if [ ! -f "$GOOGLE_APPLICATION_CREDENTIALS" ]; then
diff --git a/app/src/chat_engine.py b/app/src/chat_engine.py
index 30756234..ec7d1913 100644
--- a/app/src/chat_engine.py
+++ b/app/src/chat_engine.py
@@ -335,7 +335,6 @@ class ImagineLaEngine(BaseEngine):
     ]

     engine_id: str = "imagine-la"
-    llm: str = "gemini/gemini-2.5-pro-preview-06-05"
     name: str = "SBN Chat Engine"
     datasets = [
         "CA EDD",

From dcd8e45b40358466b79c5aba122a366ec5d9aa71 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Wed, 18 Jun 2025 12:45:49 -0500
Subject: [PATCH 15/19] add to terraform

---
 infra/app/app-config/env-config/environment_variables.tf | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/infra/app/app-config/env-config/environment_variables.tf b/infra/app/app-config/env-config/environment_variables.tf
index 38f3155b..348fc4dc 100644
--- a/infra/app/app-config/env-config/environment_variables.tf
+++ b/infra/app/app-config/env-config/environment_variables.tf
@@ -36,6 +36,10 @@ locals {
       manage_method     = "manual"
       secret_store_name = "/${var.app_name}-${var.environment}/OPENAI_API_KEY"
     }
+    GEMINI_API_KEY = {
+      manage_method     = "manual"
+      secret_store_name = "/${var.app_name}-${var.environment}/GEMINI_API_KEY"
+    }
     LITERAL_API_KEY = {
       manage_method     = "manual"
       secret_store_name = "/${var.app_name}-${var.environment}/LITERAL_API_KEY"

From 96972312d4fabe3198ec449d2af200f3594b17fa Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Wed, 18 Jun 2025 13:48:49 -0500
Subject: [PATCH 16/19] compare by chunk id

---
 app/tests/src/test_retrieve.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/app/tests/src/test_retrieve.py b/app/tests/src/test_retrieve.py
index 93720fde..6c84d9cb 100644
--- a/app/tests/src/test_retrieve.py
+++ b/app/tests/src/test_retrieve.py
@@ -22,8 +22,8 @@ def _create_chunks(document=None):
     ]

-def _format_retrieval_results(retrieval_results):
-    return [chunk_with_score.chunk for chunk_with_score in retrieval_results]
+def _chunk_ids(retrieval_results):
+    return [chunk_with_score.chunk.id for chunk_with_score in retrieval_results]

 def test_retrieve__with_empty_filter(app_config, db_session, enable_factory_create):
     results = retrieve_with_scores(
         "Very tiny words.", retrieval_k=2, retrieval_k_min_score=0.0, datasets=[]
     )
-    assert _format_retrieval_results(results) == [short_chunk, medium_chunk]
+    assert _chunk_ids(results) == [short_chunk.id, medium_chunk.id]

 def test_retrieve__with_unknown_filter(app_config, db_session, enable_factory_create):
@@ -59,7 +59,7 @@ def test_retrieve__with_dataset_filter(app_config, db_session, enable_factory_cr
         retrieval_k_min_score=0.0,
         datasets=["SNAP"],
     )
-    assert _format_retrieval_results(results) == [snap_short_chunk, snap_medium_chunk]
+    assert _chunk_ids(results) == [snap_short_chunk.id, snap_medium_chunk.id]

 def test_retrieve__with_other_filters(app_config, db_session, enable_factory_create):
@@ -76,7 +76,7 @@ def test_retrieve__with_other_filters(app_config, db_session, enable_factory_cre
         programs=["SNAP"],
         regions=["MI"],
     )
-    assert _format_retrieval_results(results) == [snap_short_chunk, snap_medium_chunk]
+    assert _chunk_ids(results) == [snap_short_chunk.id, snap_medium_chunk.id]

 def test_retrieve_with_scores(app_config, db_session, enable_factory_create):
@@ -86,7 +86,7 @@ def test_retrieve_with_scores(app_config, db_session, enable_factory_create):
     results = retrieve_with_scores("Very tiny words.", retrieval_k=2, retrieval_k_min_score=0.0)

     assert len(results) == 2
-    assert results[0].chunk == short_chunk
+    assert results[0].chunk.id == short_chunk.id
     assert results[0].score == 0.7071067690849304
-    assert results[1].chunk == medium_chunk
+    assert results[1].chunk.id == medium_chunk.id
     assert results[1].score == 0.25881901383399963

From c22c07da80144c7da79d97e57635bed607f5b524 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Wed, 18 Jun 2025 15:27:47 -0500
Subject: [PATCH 17/19] fix more tests

---
 app/tests/src/evaluation/cli/test_cli_generate.py | 9 ++-------
 app/tests/src/test_chat_api.py                    | 1 -
 2 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/app/tests/src/evaluation/cli/test_cli_generate.py b/app/tests/src/evaluation/cli/test_cli_generate.py
index 0bfe77ea..150f9253 100644
--- a/app/tests/src/evaluation/cli/test_cli_generate.py
+++ b/app/tests/src/evaluation/cli/test_cli_generate.py
@@ -32,7 +32,6 @@ def test_dataset_mapping():
     assert map_dataset_name("unknown_dataset") == "unknown_dataset"

-@pytest.mark.skip(reason="")
 def test_argument_parsing():
     """Test argument parsing with various combinations."""
     # Test default values
@@ -75,7 +74,6 @@ def validate_sampling_fraction(value):
         raise argparse.ArgumentTypeError(f"{value} is not a valid sampling fraction") from err

-@pytest.mark.skip(reason="")
 def test_invalid_arguments():
     """Test handling of invalid arguments."""
     parser = generate.create_parser()
@@ -92,9 +90,8 @@ def test_invalid_arguments():
         parser.parse_args(["--random-seed", "not_a_number"])

-@pytest.mark.skip(reason="")
 @pytest.mark.integration
-def test_main_integration(temp_output_dir):
+def test_main_integration(temp_output_dir, app_config):
     """Integration test with minimal test data."""
     with mock.patch(
         "sys.argv",
@@ -125,8 +122,7 @@ def test_main_integration(temp_output_dir):
     assert mock_qa_pairs_path.exists()

-@pytest.mark.skip(reason="")
-def test_error_handling_no_documents():
+def test_error_handling_no_documents(app_config):
     """Test handling of 'No documents found' error."""
     with mock.patch("sys.argv", ["generate.py"]):
         with mock.patch("src.evaluation.qa_generation.runner.run_generation") as mock_run:
             mock_run.side_effect = ValueError("No documents found matching filter criteria")
             generate.main()  # This should handle the error and return

-@pytest.mark.skip(reason="")
 def test_output_directory_handling(temp_output_dir):
     """Test output directory path handling."""
     # Test relative path
diff --git a/app/tests/src/test_chat_api.py b/app/tests/src/test_chat_api.py
index e0e955bc..d3352bf5 100644
--- a/app/tests/src/test_chat_api.py
+++ b/app/tests/src/test_chat_api.py
@@ -161,7 +161,6 @@ async def mock_run_query(engine, question, chat_history):
     )

-@pytest.mark.skip(reason="")
 @pytest.mark.asyncio
 async def test_api_query(async_client, monkeypatch, db_session):
     monkeypatch.setattr("src.chat_api.run_query", mock_run_query)

From 8e89b00e88d7ade65705dce1df3d7eb9e8ce32a7 Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Wed, 18 Jun 2025 15:29:25 -0500
Subject: [PATCH 18/19] remove print()

---
 app/src/db/pg_dump_util.py     | 5 -----
 app/tests/src/test_chat_api.py | 1 -
 2 files changed, 6 deletions(-)

diff --git a/app/src/db/pg_dump_util.py b/app/src/db/pg_dump_util.py
index ed59f2cb..7bf8feb3 100644
--- a/app/src/db/pg_dump_util.py
+++ b/app/src/db/pg_dump_util.py
@@ -43,8 +43,6 @@ def backup_db(dumpfilename: str, env: str) -> None:
     _print_row_counts()
     logger.info("Writing DB dump to %r", dumpfilename)
     config_dict = _get_db_config()
-    print("---------------DB config:")
-    print(config_dict)
     if not _pg_dump(config_dict, dumpfilename):
         logger.fatal("Failed to dump DB data to %r", dumpfilename)
         return
@@ -160,7 +158,6 @@ def _pg_dump(config_dict: dict[str, str], stdout_file: str) -> bool:
     os.environ["PGPASSWORD"] = config_dict["password"]
     # Unit test sets DB_SCHEMA to avoid affecting the real DB
     schema = _get_schema_name()
-    print("schema:", schema)
     command = [
         "pg_dump",
         "--data-only",
@@ -173,8 +170,6 @@ def _pg_dump(config_dict: dict[str, str], stdout_file: str) -> bool:
         schema,
         config_dict["dbname"],
     ]
-    print("command:", " ".join(command))
-    # import pdb; pdb.set_trace()
     return _run_command(command, dumpfile)
diff --git a/app/tests/src/test_chat_api.py b/app/tests/src/test_chat_api.py
index d3352bf5..d7aef685 100644
--- a/app/tests/src/test_chat_api.py
+++ b/app/tests/src/test_chat_api.py
@@ -115,7 +115,6 @@ async def test_api_engines__dbsession_contextvar(async_client, monkeypatch, db_s
     async def wait_for_all_requests(_self, *_args):
         db_session = chat_api.dbsession.get()
-        print(f"DB session in wait_for_all_requests: {db_session}")
         db_sessions.append(db_session)
         if len(db_sessions) < 2:
             # Wait to allow the event loop to run other tasks

From a960bb4a281505109fc0cba00d66b5907450cabe Mon Sep 17 00:00:00 2001
From: Yoom Lam
Date: Mon, 23 Jun 2025 09:41:41 -0500
Subject: [PATCH 19/19] remove old code

---
 app/src/app_config.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/app/src/app_config.py b/app/src/app_config.py
index 3f6aeb9d..4411595e 100644
--- a/app/src/app_config.py
+++ b/app/src/app_config.py
@@ -46,11 +46,6 @@ class AppConfig(PydanticBaseEnvConfig):
     def db_client(self) -> db.PostgresDBClient:
         return db.PostgresDBClient()

-    # def db_session(self) -> db.Session:
-    #     import pdb
-    #     pdb.set_trace()
-    #     return db.PostgresDBClient().get_session()
-
     def db_session(self) -> db.Session:
         return self.db_client.get_session()
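
Note on the pattern that patches 01/19 and 19/19 converge on: functools.cached_property evaluates db_client once per AppConfig instance and caches the result, so every db_session() call reuses the same db.PostgresDBClient instead of constructing a new client (and connection pool) per session. The sketch below illustrates only that caching behavior; FakeDBClient and AppConfigSketch are hypothetical stand-ins and are not part of the repository code.

    from functools import cached_property


    class FakeDBClient:
        """Hypothetical stand-in for db.PostgresDBClient, which manages a connection pool."""

        instances = 0

        def __init__(self) -> None:
            FakeDBClient.instances += 1

        def get_session(self) -> str:
            return f"session from client #{FakeDBClient.instances}"


    class AppConfigSketch:
        @cached_property
        def db_client(self) -> FakeDBClient:
            # Evaluated once per instance; the result is stored in the
            # instance's __dict__ and returned on every later access.
            return FakeDBClient()

        def db_session(self) -> str:
            # Every session comes from the same lazily created client.
            return self.db_client.get_session()


    config = AppConfigSketch()
    config.db_session()
    config.db_session()
    assert FakeDBClient.instances == 1  # one client serves all sessions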