1
1
import logging
2
2
from datetime import datetime
3
3
from decimal import Decimal
4
- from typing import Any , Mapping , Optional
4
+ from typing import Any , Callable , Mapping , Optional , Union
5
5
6
6
try :
7
7
import psycopg2
21
21
ClientConnectSuccessCallback ,
22
22
SinkBatch ,
23
23
)
24
+ from quixstreams .sinks .base .item import SinkItem
24
25
25
26
__all__ = ("PostgreSQLSink" , "PostgreSQLSinkException" )
26
27
@@ -58,7 +59,8 @@ def __init__(
58
59
dbname : str ,
59
60
user : str ,
60
61
password : str ,
61
- table_name : str ,
62
+ table_name : Union [Callable [[SinkItem ], str ], str ],
63
+ schema_name : str = "public" ,
62
64
schema_auto_update : bool = True ,
63
65
connection_timeout_seconds : int = 30 ,
64
66
statement_timeout_seconds : int = 30 ,
@@ -72,9 +74,13 @@ def __init__(
72
74
:param host: PostgreSQL server address.
73
75
:param port: PostgreSQL server port.
74
76
:param dbname: PostgreSQL database name.
75
- :param user: Database user name .
77
+ :param user: Database username .
76
78
:param password: Database user password.
77
- :param table_name: PostgreSQL table name.
79
+ :param table_name: PostgreSQL table name as either a string or a callable which
80
+ receives a SinkItem and returns a string.
81
+ :param schema_name: The schema name. Schemas are a way of organizing tables and
82
+ not related to the table data, referenced as `<schema_name>.<table_name>`.
83
+ PostgreSQL uses "public" by default under the hood.
78
84
:param schema_auto_update: Automatically update the schema when new columns are detected.
79
85
:param connection_timeout_seconds: Timeout for connection.
80
86
:param statement_timeout_seconds: Timeout for DDL operations such as table
@@ -91,9 +97,10 @@ def __init__(
91
97
on_client_connect_success = on_client_connect_success ,
92
98
on_client_connect_failure = on_client_connect_failure ,
93
99
)
94
-
95
- self .table_name = table_name
96
- self .schema_auto_update = schema_auto_update
100
+ self ._table_name = _table_name_setter (table_name )
101
+ self ._tables = set ()
102
+ self ._schema_name = schema_name
103
+ self ._schema_auto_update = schema_auto_update
97
104
options = kwargs .pop ("options" , "" )
98
105
if "statement_timeout" not in options :
99
106
options = f"{ options } -c statement_timeout={ statement_timeout_seconds } s"
@@ -111,38 +118,48 @@ def __init__(
111
118
112
119
def setup(self):
    """Establish the PostgreSQL connection and ensure the target schema exists.

    Called once before the sink starts writing; ``self._client_settings`` was
    assembled in ``__init__``. Schema creation is idempotent
    (CREATE SCHEMA IF NOT EXISTS), so repeated setups are safe.
    """
    self._client = psycopg2.connect(**self._client_settings)
    self._create_schema()
118
122
119
123
def write(self, batch: SinkBatch):
    """Write a batch of records, grouping items by their resolved table name.

    For every item the destination table is resolved via ``self._table_name``
    (a callable), the record key and non-null values are collected into a row,
    and the column -> Python-type mapping is accumulated so new columns can be
    added when ``schema_auto_update`` is enabled. All tables are written inside
    a single transaction; any psycopg2 error rolls it back and is re-raised as
    :class:`PostgreSQLSinkException`.

    :param batch: the ``SinkBatch`` of items to persist.
    :raises PostgreSQLSinkException: if the database write fails.
    """
    tables = {}
    for item in batch:
        destination = self._table_name(item)
        table = tables.setdefault(destination, {"rows": [], "cols_types": {}})

        row = {}
        if item.key is not None:
            table["cols_types"].setdefault(_KEY_COLUMN_NAME, type(item.key))
            row[_KEY_COLUMN_NAME] = item.key

        for column, value in item.value.items():
            if value is not None:
                table["cols_types"].setdefault(column, type(value))
                row[column] = value

        # Kafka timestamps are epoch milliseconds; convert to datetime.
        row[_TIMESTAMP_COLUMN_NAME] = datetime.fromtimestamp(item.timestamp / 1000)
        table["rows"].append(row)

    try:
        # Connection as context manager: commit on success, rollback on error.
        with self._client:
            for name, data in tables.items():
                if self._schema_auto_update:
                    self._create_table(name)
                    self._add_new_columns(name, data["cols_types"])
                self._insert_rows(name, data["rows"])
    except psycopg2.Error as e:
        self._client.rollback()
        raise PostgreSQLSinkException(f"Failed to write batch: {str(e)}") from e

    table_counts = {name: len(data["rows"]) for name, data in tables.items()}
    schema_log = (
        " "
        if self._schema_name == "public"
        else f" for schema '{self._schema_name}' "
    )
    logger.info(
        f"Successfully wrote records{schema_log}to tables; "
        f"table row counts: {table_counts}"
    )
146
163
147
164
def add (
148
165
self ,
@@ -169,7 +186,17 @@ def add(
169
186
offset = offset ,
170
187
)
171
188
172
- def _init_table (self ):
189
def _create_schema(self):
    """Create the configured schema if it does not already exist (idempotent)."""
    with self._client.cursor() as cursor:
        cursor.execute(
            sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(
                sql.Identifier(self._schema_name)
            )
        )
196
+
197
+ def _create_table (self , table_name : str ):
198
+ if table_name in self ._tables :
199
+ return
173
200
query = sql .SQL (
174
201
"""
175
202
CREATE TABLE IF NOT EXISTS {table} (
@@ -178,15 +205,15 @@ def _init_table(self):
178
205
)
179
206
"""
180
207
).format (
181
- table = sql .Identifier (self .table_name ),
208
+ table = sql .Identifier (self ._schema_name , table_name ),
182
209
timestamp_col = sql .Identifier (_TIMESTAMP_COLUMN_NAME ),
183
210
key_col = sql .Identifier (_KEY_COLUMN_NAME ),
184
211
)
185
212
186
213
with self ._client .cursor () as cursor :
187
214
cursor .execute (query )
188
215
189
- def _add_new_columns (self , columns : dict [str , type ]) -> None :
216
+ def _add_new_columns (self , table_name : str , columns : dict [str , type ]) -> None :
190
217
for col_name , py_type in columns .items ():
191
218
postgres_col_type = _POSTGRES_TYPES_MAP .get (py_type )
192
219
if postgres_col_type is None :
@@ -200,15 +227,15 @@ def _add_new_columns(self, columns: dict[str, type]) -> None:
200
227
ADD COLUMN IF NOT EXISTS {column} {col_type}
201
228
"""
202
229
).format (
203
- table = sql .Identifier (self .table_name ),
230
+ table = sql .Identifier (self ._schema_name , table_name ),
204
231
column = sql .Identifier (col_name ),
205
232
col_type = sql .SQL (postgres_col_type ),
206
233
)
207
234
208
235
with self ._client .cursor () as cursor :
209
236
cursor .execute (query )
210
237
211
- def _insert_rows (self , rows : list [dict ]) -> None :
238
+ def _insert_rows (self , table_name : str , rows : list [dict ]) -> None :
212
239
if not rows :
213
240
return
214
241
@@ -218,9 +245,17 @@ def _insert_rows(self, rows: list[dict]) -> None:
218
245
values = [[row .get (col , None ) for col in columns ] for row in rows ]
219
246
220
247
query = sql .SQL ("INSERT INTO {table} ({columns}) VALUES %s" ).format (
221
- table = sql .Identifier (self .table_name ),
248
+ table = sql .Identifier (self ._schema_name , table_name ),
222
249
columns = sql .SQL (", " ).join (map (sql .Identifier , columns )),
223
250
)
224
251
225
252
with self ._client .cursor () as cursor :
226
253
execute_values (cursor , query , values )
254
+
255
+
256
def _table_name_setter(
    table_name: Union[Callable[[SinkItem], str], str],
) -> Callable[[SinkItem], str]:
    """Normalize ``table_name`` into a per-item resolver.

    A plain string is wrapped in a constant-returning callable; a callable is
    passed through unchanged, so the sink can always call it with a SinkItem.
    """
    return (lambda _item: table_name) if isinstance(table_name, str) else table_name
0 commit comments