1
- use anyhow:: { anyhow, Result } ;
1
+ use anyhow:: { anyhow, Context , Result } ;
2
2
use native_tls:: TlsConnector ;
3
3
use postgres_native_tls:: MakeTlsConnector ;
4
4
use spin_world:: async_trait;
5
5
use spin_world:: spin:: postgres:: postgres:: {
6
6
self as v3, Column , DbDataType , DbValue , ParameterValue , RowSet ,
7
7
} ;
8
8
use tokio_postgres:: types:: Type ;
9
- use tokio_postgres:: { config:: SslMode , types:: ToSql , Row } ;
10
- use tokio_postgres:: { Client as TokioClient , NoTls , Socket } ;
9
+ use tokio_postgres:: { config:: SslMode , types:: ToSql , NoTls , Row } ;
10
+
11
+ const CONNECTION_POOL_SIZE : usize = 64 ;
11
12
12
13
#[ async_trait]
13
- pub trait Client {
14
- async fn build_client ( address : & str ) -> Result < Self >
15
- where
16
- Self : Sized ;
14
+ pub trait ClientFactory : Send + Sync {
15
+ type Client : Client + Send + Sync + ' static ;
16
+ fn new ( ) -> Self ;
17
+ async fn build_client ( & mut self , address : & str ) -> Result < Self :: Client > ;
18
+ }
19
+
20
+ pub struct PooledTokioClientFactory {
21
+ pools : std:: collections:: HashMap < String , deadpool_postgres:: Pool > ,
22
+ }
23
+
24
+ #[ async_trait]
25
+ impl ClientFactory for PooledTokioClientFactory {
26
+ type Client = deadpool_postgres:: Object ;
27
+
28
+ fn new ( ) -> Self {
29
+ Self {
30
+ pools : Default :: default ( ) ,
31
+ }
32
+ }
33
+
34
+ async fn build_client ( & mut self , address : & str ) -> Result < Self :: Client > {
35
+ let pool_entry = self . pools . entry ( address. to_owned ( ) ) ;
36
+ let pool = match pool_entry {
37
+ std:: collections:: hash_map:: Entry :: Occupied ( entry) => entry. into_mut ( ) ,
38
+ std:: collections:: hash_map:: Entry :: Vacant ( entry) => {
39
+ let pool = create_connection_pool ( address)
40
+ . context ( "establishing PostgreSQL connection pool" ) ?;
41
+ entry. insert ( pool)
42
+ }
43
+ } ;
44
+
45
+ Ok ( pool. get ( ) . await ?)
46
+ }
47
+ }
48
+
49
+ fn create_connection_pool ( address : & str ) -> Result < deadpool_postgres:: Pool > {
50
+ let config = address
51
+ . parse :: < tokio_postgres:: Config > ( )
52
+ . context ( "parsing Postgres connection string" ) ?;
53
+
54
+ tracing:: debug!( "Build new connection: {}" , address) ;
17
55
56
+ // TODO: This is slower but safer. Is it the right tradeoff?
57
+ // https://docs.rs/deadpool-postgres/latest/deadpool_postgres/enum.RecyclingMethod.html
58
+ let mgr_config = deadpool_postgres:: ManagerConfig {
59
+ recycling_method : deadpool_postgres:: RecyclingMethod :: Clean ,
60
+ } ;
61
+
62
+ let mgr = if config. get_ssl_mode ( ) == SslMode :: Disable {
63
+ deadpool_postgres:: Manager :: from_config ( config, NoTls , mgr_config)
64
+ } else {
65
+ let builder = TlsConnector :: builder ( ) ;
66
+ let connector = MakeTlsConnector :: new ( builder. build ( ) ?) ;
67
+ deadpool_postgres:: Manager :: from_config ( config, connector, mgr_config)
68
+ } ;
69
+
70
+ // TODO: what is our max size heuristic? Should this be passed in soe that different
71
+ // hosts can manage it according to their needs? Will a plain number suffice for
72
+ // sophisticated hosts anyway?
73
+ let pool = deadpool_postgres:: Pool :: builder ( mgr)
74
+ . max_size ( CONNECTION_POOL_SIZE )
75
+ . build ( )
76
+ . context ( "building Postgres connection pool" ) ?;
77
+
78
+ Ok ( pool)
79
+ }
80
+
81
+ #[ async_trait]
82
+ pub trait Client {
18
83
async fn execute (
19
84
& self ,
20
85
statement : String ,
@@ -29,28 +94,7 @@ pub trait Client {
29
94
}
30
95
31
96
#[ async_trait]
32
- impl Client for TokioClient {
33
- async fn build_client ( address : & str ) -> Result < Self >
34
- where
35
- Self : Sized ,
36
- {
37
- let config = address. parse :: < tokio_postgres:: Config > ( ) ?;
38
-
39
- tracing:: debug!( "Build new connection: {}" , address) ;
40
-
41
- if config. get_ssl_mode ( ) == SslMode :: Disable {
42
- let ( client, connection) = config. connect ( NoTls ) . await ?;
43
- spawn_connection ( connection) ;
44
- Ok ( client)
45
- } else {
46
- let builder = TlsConnector :: builder ( ) ;
47
- let connector = MakeTlsConnector :: new ( builder. build ( ) ?) ;
48
- let ( client, connection) = config. connect ( connector) . await ?;
49
- spawn_connection ( connection) ;
50
- Ok ( client)
51
- }
52
- }
53
-
97
+ impl Client for deadpool_postgres:: Object {
54
98
async fn execute (
55
99
& self ,
56
100
statement : String ,
@@ -67,7 +111,8 @@ impl Client for TokioClient {
67
111
. map ( |b| b. as_ref ( ) as & ( dyn ToSql + Sync ) )
68
112
. collect ( ) ;
69
113
70
- self . execute ( & statement, params_refs. as_slice ( ) )
114
+ self . as_ref ( )
115
+ . execute ( & statement, params_refs. as_slice ( ) )
71
116
. await
72
117
. map_err ( |e| v3:: Error :: QueryFailed ( format ! ( "{e:?}" ) ) )
73
118
}
@@ -89,6 +134,7 @@ impl Client for TokioClient {
89
134
. collect ( ) ;
90
135
91
136
let results = self
137
+ . as_ref ( )
92
138
. query ( & statement, params_refs. as_slice ( ) )
93
139
. await
94
140
. map_err ( |e| v3:: Error :: QueryFailed ( format ! ( "{e:?}" ) ) ) ?;
@@ -111,17 +157,6 @@ impl Client for TokioClient {
111
157
}
112
158
}
113
159
114
- fn spawn_connection < T > ( connection : tokio_postgres:: Connection < Socket , T > )
115
- where
116
- T : tokio_postgres:: tls:: TlsStream + std:: marker:: Unpin + std:: marker:: Send + ' static ,
117
- {
118
- tokio:: spawn ( async move {
119
- if let Err ( e) = connection. await {
120
- tracing:: error!( "Postgres connection error: {}" , e) ;
121
- }
122
- } ) ;
123
- }
124
-
125
160
fn to_sql_parameter ( value : & ParameterValue ) -> Result < Box < dyn ToSql + Send + Sync > > {
126
161
match value {
127
162
ParameterValue :: Boolean ( v) => Ok ( Box :: new ( * v) ) ,
0 commit comments