Skip to content

Commit 674b1ee

Browse files
move Postgres settings into separate settings class (#209)
* make postgres settings values optional (but validate later) * split Postgres settings into separate PostgresSettings class * update changelog * fix a few tests --------- Co-authored-by: vincentsarago <vincent.sarago@gmail.com>
1 parent d379c89 commit 674b1ee

File tree

6 files changed

+113
-58
lines changed

6 files changed

+113
-58
lines changed

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## [Unreleased]
44

5+
### Changed
6+
7+
- move Postgres settings into separate `PostgresSettings` class and defer loading until connecting to database ([#209](https://github.com/stac-utils/stac-fastapi-pgstac/pull/209))
8+
59
## [4.0.3] - 2025-03-10
610

711
### Fixed

stac_fastapi/pgstac/config.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from urllib.parse import quote_plus as quote
55

66
from pydantic import BaseModel, field_validator
7-
from pydantic_settings import SettingsConfigDict
7+
from pydantic_settings import BaseSettings, SettingsConfigDict
88
from stac_fastapi.types.config import ApiSettings
99

1010
from stac_fastapi.pgstac.types.base_item_cache import (
@@ -43,7 +43,7 @@ class ServerSettings(BaseModel):
4343
model_config = SettingsConfigDict(extra="allow")
4444

4545

46-
class Settings(ApiSettings):
46+
class PostgresSettings(BaseSettings):
4747
"""Postgres-specific API settings.
4848
4949
Attributes:
@@ -71,9 +71,28 @@ class Settings(ApiSettings):
7171

7272
server_settings: ServerSettings = ServerSettings()
7373

74+
model_config = {"env_file": ".env", "extra": "ignore"}
75+
76+
@property
77+
def reader_connection_string(self):
78+
"""Create reader psql connection string."""
79+
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}"
80+
81+
@property
82+
def writer_connection_string(self):
83+
"""Create writer psql connection string."""
84+
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}"
85+
86+
@property
87+
def testing_connection_string(self):
88+
"""Create testing psql connection string."""
89+
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb"
90+
91+
92+
class Settings(ApiSettings):
7493
use_api_hydrate: bool = False
75-
base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache
7694
invalid_id_chars: List[str] = DEFAULT_INVALID_ID_CHARS
95+
base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache
7796

7897
cors_origins: str = "*"
7998
cors_methods: str = "GET,POST,OPTIONS"
@@ -89,22 +108,3 @@ def parse_cors_origin(cls, v):
89108
def parse_cors_methods(cls, v):
90109
"""Parse CORS methods."""
91110
return [method.strip() for method in v.split(",")]
92-
93-
@property
94-
def reader_connection_string(self):
95-
"""Create reader psql connection string."""
96-
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}"
97-
98-
@property
99-
def writer_connection_string(self):
100-
"""Create writer psql connection string."""
101-
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}"
102-
103-
@property
104-
def testing_connection_string(self):
105-
"""Create testing psql connection string."""
106-
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb"
107-
108-
model_config = SettingsConfigDict(
109-
**{**ApiSettings.model_config, **{"env_nested_delimiter": "__"}}
110-
)

stac_fastapi/pgstac/db.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
NotFoundError,
2626
)
2727

28+
from stac_fastapi.pgstac.config import PostgresSettings
29+
2830

2931
async def con_init(conn):
3032
"""Use orjson for json returns."""
@@ -46,19 +48,25 @@ async def con_init(conn):
4648

4749

4850
async def connect_to_db(
49-
app: FastAPI, get_conn: Optional[ConnectionGetter] = None
51+
app: FastAPI,
52+
get_conn: Optional[ConnectionGetter] = None,
53+
postgres_settings: Optional[PostgresSettings] = None,
5054
) -> None:
5155
"""Create connection pools & connection retriever on application."""
52-
settings = app.state.settings
53-
if app.state.settings.testing:
54-
readpool = writepool = settings.testing_connection_string
56+
app_settings = app.state.settings
57+
58+
if not postgres_settings:
59+
postgres_settings = PostgresSettings()
60+
61+
if app_settings.testing:
62+
readpool = writepool = postgres_settings.testing_connection_string
5563
else:
56-
readpool = settings.reader_connection_string
57-
writepool = settings.writer_connection_string
64+
readpool = postgres_settings.reader_connection_string
65+
writepool = postgres_settings.writer_connection_string
5866

5967
db = DB()
60-
app.state.readpool = await db.create_pool(readpool, settings)
61-
app.state.writepool = await db.create_pool(writepool, settings)
68+
app.state.readpool = await db.create_pool(readpool, postgres_settings)
69+
app.state.writepool = await db.create_pool(writepool, postgres_settings)
6270
app.state.get_connection = get_conn if get_conn else get_connection
6371

6472

tests/api/test_api.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from stac_fastapi.extensions.core.fields import FieldsConformanceClasses
2121
from stac_fastapi.types import stac as stac_types
2222

23+
from stac_fastapi.pgstac.config import PostgresSettings
2324
from stac_fastapi.pgstac.core import CoreCrudClient, Settings
2425
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
2526
from stac_fastapi.pgstac.transactions import TransactionsClient
@@ -720,13 +721,16 @@ async def get_collection(
720721
return await super().get_collection(collection_id, request=request, **kwargs)
721722

722723
settings = Settings(
724+
testing=True,
725+
)
726+
727+
postgres_settings = PostgresSettings(
723728
postgres_user=database.user,
724729
postgres_pass=database.password,
725730
postgres_host_reader=database.host,
726731
postgres_host_writer=database.host,
727732
postgres_port=database.port,
728733
postgres_dbname=database.dbname,
729-
testing=True,
730734
)
731735

732736
extensions = [
@@ -751,7 +755,7 @@ async def get_collection(
751755
collections_get_request_model=collection_search_extension.GET,
752756
)
753757
app = api.app
754-
await connect_to_db(app)
758+
await connect_to_db(app, postgres_settings=postgres_settings)
755759
try:
756760
async with AsyncClient(transport=ASGITransport(app=app)) as client:
757761
response = await client.post(
@@ -786,15 +790,17 @@ async def test_no_extension(
786790
loader.load_items(os.path.join(DATA_DIR, "test_item.json"))
787791

788792
settings = Settings(
793+
testing=True,
794+
use_api_hydrate=hydrate,
795+
enable_response_models=validation,
796+
)
797+
postgres_settings = PostgresSettings(
789798
postgres_user=database.user,
790799
postgres_pass=database.password,
791800
postgres_host_reader=database.host,
792801
postgres_host_writer=database.host,
793802
postgres_port=database.port,
794803
postgres_dbname=database.dbname,
795-
testing=True,
796-
use_api_hydrate=hydrate,
797-
enable_response_models=validation,
798804
)
799805
extensions = []
800806
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
@@ -805,7 +811,7 @@ async def test_no_extension(
805811
search_post_request_model=post_request_model,
806812
)
807813
app = api.app
808-
await connect_to_db(app)
814+
await connect_to_db(app, postgres_settings=postgres_settings)
809815
try:
810816
async with AsyncClient(transport=ASGITransport(app=app)) as client:
811817
landing = await client.get("http://test/")

tests/clients/test_postgres.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
import pytest
88
from fastapi import Request
9+
from pydantic import ValidationError
910
from stac_pydantic import Collection, Item
1011

12+
from stac_fastapi.pgstac.config import PostgresSettings
1113
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db, get_connection
1214

1315
# from tests.conftest import MockStarletteRequest
@@ -523,6 +525,28 @@ async def test_create_bulk_items_id_mismatch(
523525
# assert item.collection == coll.id
524526

525527

528+
async def test_db_setup_works_with_env_vars(api_client, database, monkeypatch):
529+
"""Test that the application starts successfully if the POSTGRES_* environment variables are set"""
530+
monkeypatch.setenv("POSTGRES_USER", database.user)
531+
monkeypatch.setenv("POSTGRES_PASS", database.password)
532+
monkeypatch.setenv("POSTGRES_HOST_READER", database.host)
533+
monkeypatch.setenv("POSTGRES_HOST_WRITER", database.host)
534+
monkeypatch.setenv("POSTGRES_PORT", str(database.port))
535+
monkeypatch.setenv("POSTGRES_DBNAME", database.dbname)
536+
537+
await connect_to_db(api_client.app)
538+
await close_db_connection(api_client.app)
539+
540+
541+
async def test_db_setup_fails_without_env_vars(api_client):
542+
"""Test that the application fails to start if database environment variables are not set."""
543+
try:
544+
await connect_to_db(api_client.app)
545+
except ValidationError:
546+
await close_db_connection(api_client.app)
547+
pytest.raises(ValidationError)
548+
549+
526550
@asynccontextmanager
527551
async def custom_get_connection(
528552
request: Request,
@@ -536,12 +560,21 @@ async def custom_get_connection(
536560

537561
class TestDbConnect:
538562
@pytest.fixture
539-
async def app(self, api_client):
563+
async def app(self, api_client, database):
540564
"""
541565
app fixture override to setup app with a customized db connection getter
542566
"""
567+
postgres_settings = PostgresSettings(
568+
postgres_user=database.user,
569+
postgres_pass=database.password,
570+
postgres_host_reader=database.host,
571+
postgres_host_writer=database.host,
572+
postgres_port=database.port,
573+
postgres_dbname=database.dbname,
574+
)
575+
543576
logger.debug("Customizing app setup")
544-
await connect_to_db(api_client.app, custom_get_connection)
577+
await connect_to_db(api_client.app, custom_get_connection, postgres_settings)
545578
yield api_client.app
546579
await close_db_connection(api_client.app)
547580

tests/conftest.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from stac_fastapi.extensions.third_party import BulkTransactionExtension
4242
from stac_pydantic import Collection, Item
4343

44-
from stac_fastapi.pgstac.config import Settings
44+
from stac_fastapi.pgstac.config import PostgresSettings, Settings
4545
from stac_fastapi.pgstac.core import CoreCrudClient
4646
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
4747
from stac_fastapi.pgstac.extensions import QueryExtension
@@ -111,18 +111,12 @@ async def pgstac(database):
111111
],
112112
scope="session",
113113
)
114-
def api_client(request, database):
114+
def api_client(request):
115115
hydrate, prefix, response_model = request.param
116116
api_settings = Settings(
117-
postgres_user=database.user,
118-
postgres_pass=database.password,
119-
postgres_host_reader=database.host,
120-
postgres_host_writer=database.host,
121-
postgres_port=database.port,
122-
postgres_dbname=database.dbname,
123-
use_api_hydrate=hydrate,
124117
enable_response_models=response_model,
125118
testing=True,
119+
use_api_hydrate=hydrate,
126120
)
127121

128122
api_settings.openapi_url = prefix + api_settings.openapi_url
@@ -203,11 +197,19 @@ def api_client(request, database):
203197

204198

205199
@pytest.fixture(scope="function")
206-
async def app(api_client):
200+
async def app(api_client, database):
201+
postgres_settings = PostgresSettings(
202+
postgres_user=database.user,
203+
postgres_pass=database.password,
204+
postgres_host_reader=database.host,
205+
postgres_host_writer=database.host,
206+
postgres_port=database.port,
207+
postgres_dbname=database.dbname,
208+
)
207209
logger.info("Creating app Fixture")
208210
time.time()
209211
app = api_client.app
210-
await connect_to_db(app)
212+
await connect_to_db(app, postgres_settings=postgres_settings)
211213

212214
yield app
213215

@@ -290,14 +292,8 @@ async def load_test2_item(app_client, load_test_data, load_test2_collection):
290292
@pytest.fixture(
291293
scope="session",
292294
)
293-
def api_client_no_ext(database):
295+
def api_client_no_ext():
294296
api_settings = Settings(
295-
postgres_user=database.user,
296-
postgres_pass=database.password,
297-
postgres_host_reader=database.host,
298-
postgres_host_writer=database.host,
299-
postgres_port=database.port,
300-
postgres_dbname=database.dbname,
301297
testing=True,
302298
)
303299
return StacApi(
@@ -310,11 +306,19 @@ def api_client_no_ext(database):
310306

311307

312308
@pytest.fixture(scope="function")
313-
async def app_no_ext(api_client_no_ext):
309+
async def app_no_ext(api_client_no_ext, database):
310+
postgres_settings = PostgresSettings(
311+
postgres_user=database.user,
312+
postgres_pass=database.password,
313+
postgres_host_reader=database.host,
314+
postgres_host_writer=database.host,
315+
postgres_port=database.port,
316+
postgres_dbname=database.dbname,
317+
)
314318
logger.info("Creating app Fixture")
315319
time.time()
316320
app = api_client_no_ext.app
317-
await connect_to_db(app)
321+
await connect_to_db(app, postgres_settings=postgres_settings)
318322

319323
yield app
320324

0 commit comments

Comments
 (0)