Skip to content

Commit 62980fa

Browse files
authored
PostgreSQLSink: allow dynamic table name selection based on record (#867)
* allow dynamic table selection based on record * add schema option
1 parent 1504a44 commit 62980fa

File tree

2 files changed

+102
-27
lines changed

2 files changed

+102
-27
lines changed

docs/connectors/sinks/postgresql-sink.md

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,50 @@ PostgreSQLSink provides at-least-once guarantees, meaning that the same records
7575

7676
PostgreSQLSink accepts the following configuration parameters:
7777

78+
### Required
79+
7880
- `host`: The address of the PostgreSQL server.
7981
- `port`: The port of the PostgreSQL server.
8082
- `dbname`: The name of the PostgreSQL database.
8183
- `user`: The database user name.
8284
- `password`: The database user password.
83-
- `table_name`: The name of the PostgreSQL table where data will be written.
85+
- `table_name`: PostgreSQL table name as either a string or a callable which receives
86+
a `SinkItem` (from quixstreams.sinks.base.item) and returns a string.
87+
88+
89+
### Optional
90+
91+
- `schema_name`: The schema name. Schemas are a way of organizing tables and
92+
not related to the table data, referenced as `<schema_name>.<table_name>`.
93+
PostgreSQL uses "public" by default under the hood.
8494
- `schema_auto_update`: If True, the sink will automatically update the schema by adding new columns when new fields are detected. Default: True.
95+
96+
97+
## Testing Locally
98+
99+
Rather than connecting to a hosted PostgreSQL instance, you can alternatively test your
100+
application using a local instance of PostgreSQL using Docker:
101+
102+
1. Execute in terminal:
103+
104+
```bash
105+
docker run --rm -d --name postgres \
106+
-e POSTGRES_PASSWORD=local \
107+
-e POSTGRES_USER=local \
108+
-e POSTGRES_DB=local \
109+
-p 5432:5432 \
110+
postgres
111+
```
112+
113+
2. Use the following settings for `PostgreSQLSink` to connect:
114+
115+
```python
116+
PostgreSQLSink(
117+
host="localhost",
118+
port=5432,
119+
user="local",
120+
password="local",
121+
dbname="local",
122+
table_name="<YOUR TABLE NAME>",
123+
)
124+
```

quixstreams/sinks/community/postgresql.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from datetime import datetime
33
from decimal import Decimal
4-
from typing import Any, Mapping, Optional
4+
from typing import Any, Callable, Mapping, Optional, Union
55

66
try:
77
import psycopg2
@@ -21,6 +21,7 @@
2121
ClientConnectSuccessCallback,
2222
SinkBatch,
2323
)
24+
from quixstreams.sinks.base.item import SinkItem
2425

2526
__all__ = ("PostgreSQLSink", "PostgreSQLSinkException")
2627

@@ -58,7 +59,8 @@ def __init__(
5859
dbname: str,
5960
user: str,
6061
password: str,
61-
table_name: str,
62+
table_name: Union[Callable[[SinkItem], str], str],
63+
schema_name: str = "public",
6264
schema_auto_update: bool = True,
6365
connection_timeout_seconds: int = 30,
6466
statement_timeout_seconds: int = 30,
@@ -72,9 +74,13 @@ def __init__(
7274
:param host: PostgreSQL server address.
7375
:param port: PostgreSQL server port.
7476
:param dbname: PostgreSQL database name.
75-
:param user: Database user name.
77+
:param user: Database username.
7678
:param password: Database user password.
77-
:param table_name: PostgreSQL table name.
79+
:param table_name: PostgreSQL table name as either a string or a callable which
80+
receives a SinkItem and returns a string.
81+
:param schema_name: The schema name. Schemas are a way of organizing tables and
82+
not related to the table data, referenced as `<schema_name>.<table_name>`.
83+
PostgreSQL uses "public" by default under the hood.
7884
:param schema_auto_update: Automatically update the schema when new columns are detected.
7985
:param connection_timeout_seconds: Timeout for connection.
8086
:param statement_timeout_seconds: Timeout for DDL operations such as table
@@ -91,9 +97,10 @@ def __init__(
9197
on_client_connect_success=on_client_connect_success,
9298
on_client_connect_failure=on_client_connect_failure,
9399
)
94-
95-
self.table_name = table_name
96-
self.schema_auto_update = schema_auto_update
100+
self._table_name = _table_name_setter(table_name)
101+
self._tables = set()
102+
self._schema_name = schema_name
103+
self._schema_auto_update = schema_auto_update
97104
options = kwargs.pop("options", "")
98105
if "statement_timeout" not in options:
99106
options = f"{options} -c statement_timeout={statement_timeout_seconds}s"
@@ -111,38 +118,48 @@ def __init__(
111118

112119
def setup(self):
113120
self._client = psycopg2.connect(**self._client_settings)
114-
115-
# Initialize table if schema_auto_update is enabled
116-
if self.schema_auto_update:
117-
self._init_table()
121+
self._create_schema()
118122

119123
def write(self, batch: SinkBatch):
120-
rows = []
121-
cols_types = {}
122-
124+
tables = {}
123125
for item in batch:
126+
table = tables.setdefault(
127+
self._table_name(item), {"rows": [], "cols_types": {}}
128+
)
124129
row = {}
125130
if item.key is not None:
126131
key_type = type(item.key)
127-
cols_types.setdefault(_KEY_COLUMN_NAME, key_type)
132+
table["cols_types"].setdefault(_KEY_COLUMN_NAME, key_type)
128133
row[_KEY_COLUMN_NAME] = item.key
129134

130135
for key, value in item.value.items():
131136
if value is not None:
132-
cols_types.setdefault(key, type(value))
137+
table["cols_types"].setdefault(key, type(value))
133138
row[key] = value
134139

135140
row[_TIMESTAMP_COLUMN_NAME] = datetime.fromtimestamp(item.timestamp / 1000)
136-
rows.append(row)
141+
table["rows"].append(row)
137142

138143
try:
139144
with self._client:
140-
if self.schema_auto_update:
141-
self._add_new_columns(cols_types)
142-
self._insert_rows(rows)
145+
for name, values in tables.items():
146+
if self._schema_auto_update:
147+
self._create_table(name)
148+
self._add_new_columns(name, values["cols_types"])
149+
self._insert_rows(name, values["rows"])
143150
except psycopg2.Error as e:
144151
self._client.rollback()
145152
raise PostgreSQLSinkException(f"Failed to write batch: {str(e)}") from e
153+
table_counts = {table: len(values["rows"]) for table, values in tables.items()}
154+
schema_log = (
155+
" "
156+
if self._schema_name == "public"
157+
else f" for schema '{self._schema_name}' "
158+
)
159+
logger.info(
160+
f"Successfully wrote records{schema_log}to tables; "
161+
f"table row counts: {table_counts}"
162+
)
146163

147164
def add(
148165
self,
@@ -169,7 +186,17 @@ def add(
169186
offset=offset,
170187
)
171188

172-
def _init_table(self):
189+
def _create_schema(self):
190+
query = sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(
191+
sql.Identifier(self._schema_name)
192+
)
193+
194+
with self._client.cursor() as cursor:
195+
cursor.execute(query)
196+
197+
def _create_table(self, table_name: str):
198+
if table_name in self._tables:
199+
return
173200
query = sql.SQL(
174201
"""
175202
CREATE TABLE IF NOT EXISTS {table} (
@@ -178,15 +205,15 @@ def _init_table(self):
178205
)
179206
"""
180207
).format(
181-
table=sql.Identifier(self.table_name),
208+
table=sql.Identifier(self._schema_name, table_name),
182209
timestamp_col=sql.Identifier(_TIMESTAMP_COLUMN_NAME),
183210
key_col=sql.Identifier(_KEY_COLUMN_NAME),
184211
)
185212

186213
with self._client.cursor() as cursor:
187214
cursor.execute(query)
188215

189-
def _add_new_columns(self, columns: dict[str, type]) -> None:
216+
def _add_new_columns(self, table_name: str, columns: dict[str, type]) -> None:
190217
for col_name, py_type in columns.items():
191218
postgres_col_type = _POSTGRES_TYPES_MAP.get(py_type)
192219
if postgres_col_type is None:
@@ -200,15 +227,15 @@ def _add_new_columns(self, columns: dict[str, type]) -> None:
200227
ADD COLUMN IF NOT EXISTS {column} {col_type}
201228
"""
202229
).format(
203-
table=sql.Identifier(self.table_name),
230+
table=sql.Identifier(self._schema_name, table_name),
204231
column=sql.Identifier(col_name),
205232
col_type=sql.SQL(postgres_col_type),
206233
)
207234

208235
with self._client.cursor() as cursor:
209236
cursor.execute(query)
210237

211-
def _insert_rows(self, rows: list[dict]) -> None:
238+
def _insert_rows(self, table_name: str, rows: list[dict]) -> None:
212239
if not rows:
213240
return
214241

@@ -218,9 +245,17 @@ def _insert_rows(self, rows: list[dict]) -> None:
218245
values = [[row.get(col, None) for col in columns] for row in rows]
219246

220247
query = sql.SQL("INSERT INTO {table} ({columns}) VALUES %s").format(
221-
table=sql.Identifier(self.table_name),
248+
table=sql.Identifier(self._schema_name, table_name),
222249
columns=sql.SQL(", ").join(map(sql.Identifier, columns)),
223250
)
224251

225252
with self._client.cursor() as cursor:
226253
execute_values(cursor, query, values)
254+
255+
256+
def _table_name_setter(
257+
table_name: Union[Callable[[SinkItem], str], str],
258+
) -> Callable[[SinkItem], str]:
259+
if isinstance(table_name, str):
260+
return lambda sink_item: table_name
261+
return table_name

0 commit comments

Comments
 (0)