Skip to content

Commit ebf24c5

Browse files
committed
feat(outbound-pg): Reuse connection during request lifecycle
Also, pg_backend_pid endpoint into outbound-pg example. Before the fix: $ curl -i localhost:3000/pg_backend_pid HTTP/1.1 500 Internal Server Error content-length: 0 date: Tue, 27 Sep 2022 14:03:50 GMT thread '<unnamed>' panicked at 'assertion failed: `(left == right)` left: `592913`, right: `592914`', src/lib.rs:112:5 After the fix: $ curl -i localhost:3000/pg_backend_pid HTTP/1.1 200 OK content-length: 23 date: Tue, 27 Sep 2022 14:07:14 GMT pg_backend_pid: 595194 Ideally, this fix has to be covered by an integration test rather than manual testing through examples, but the testing environment has to be set up first. Refs: #667. Signed-off-by: Konstantin Shabanov <mail@etehtsea.me>
1 parent 01118c5 commit ebf24c5

File tree

3 files changed

+58
-21
lines changed

3 files changed

+58
-21
lines changed

crates/outbound-pg/src/lib.rs

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use anyhow::anyhow;
22
use outbound_pg::*;
3+
use std::collections::HashMap;
34
use tokio_postgres::{
4-
tls::NoTlsStream,
55
types::{ToSql, Type},
6-
Connection, NoTls, Row, Socket,
6+
Client, NoTls, Row,
77
};
88

99
pub use outbound_pg::add_to_linker;
@@ -16,8 +16,9 @@ use wit_bindgen_wasmtime::{async_trait, wasmtime::Linker};
1616
wit_bindgen_wasmtime::export!({paths: ["../../wit/ephemeral/outbound-pg.wit"], async: *});
1717

1818
/// A simple implementation to support outbound pg connection
19-
#[derive(Default, Clone)]
20-
pub struct OutboundPg;
19+
pub struct OutboundPg {
20+
pub connections: HashMap<String, Client>,
21+
}
2122

2223
impl HostComponent for OutboundPg {
2324
type State = Self;
@@ -33,7 +34,9 @@ impl HostComponent for OutboundPg {
3334
&self,
3435
_component: &spin_manifest::CoreComponent,
3536
) -> anyhow::Result<Self::State> {
36-
Ok(Self)
37+
let connections = std::collections::HashMap::new();
38+
39+
Ok(Self { connections })
3740
}
3841
}
3942

@@ -45,19 +48,16 @@ impl outbound_pg::OutboundPg for OutboundPg {
4548
statement: &str,
4649
params: Vec<ParameterValue<'_>>,
4750
) -> Result<u64, PgError> {
48-
let (client, connection) = tokio_postgres::connect(address, NoTls)
49-
.await
50-
.map_err(|e| PgError::ConnectionFailed(format!("{:?}", e)))?;
51-
52-
spawn(connection);
53-
5451
let params: Vec<&(dyn ToSql + Sync)> = params
5552
.iter()
5653
.map(to_sql_parameter)
5754
.collect::<anyhow::Result<Vec<_>>>()
5855
.map_err(|e| PgError::ValueConversionFailed(format!("{:?}", e)))?;
5956

60-
let nrow = client
57+
let nrow = self
58+
.get_client(address)
59+
.await
60+
.map_err(|e| PgError::ConnectionFailed(format!("{:?}", e)))?
6161
.execute(statement, params.as_slice())
6262
.await
6363
.map_err(|e| PgError::QueryFailed(format!("{:?}", e)))?;
@@ -71,19 +71,16 @@ impl outbound_pg::OutboundPg for OutboundPg {
7171
statement: &str,
7272
params: Vec<ParameterValue<'_>>,
7373
) -> Result<RowSet, PgError> {
74-
let (client, connection) = tokio_postgres::connect(address, NoTls)
75-
.await
76-
.map_err(|e| PgError::ConnectionFailed(format!("{:?}", e)))?;
77-
78-
spawn(connection);
79-
8074
let params: Vec<&(dyn ToSql + Sync)> = params
8175
.iter()
8276
.map(to_sql_parameter)
8377
.collect::<anyhow::Result<Vec<_>>>()
8478
.map_err(|e| PgError::BadParameter(format!("{:?}", e)))?;
8579

86-
let results = client
80+
let results = self
81+
.get_client(address)
82+
.await
83+
.map_err(|e| PgError::ConnectionFailed(format!("{:?}", e)))?
8784
.query(statement, params.as_slice())
8885
.await
8986
.map_err(|e| PgError::QueryFailed(format!("{:?}", e)))?;
@@ -246,10 +243,26 @@ fn convert_entry(row: &Row, index: usize) -> Result<DbValue, tokio_postgres::Err
246243
Ok(value)
247244
}
248245

249-
fn spawn(connection: Connection<Socket, NoTlsStream>) {
246+
impl OutboundPg {
247+
async fn get_client(&mut self, address: &str) -> anyhow::Result<&Client> {
248+
let client = match self.connections.entry(address.to_owned()) {
249+
std::collections::hash_map::Entry::Occupied(o) => o.into_mut(),
250+
std::collections::hash_map::Entry::Vacant(v) => v.insert(build_client(address).await?),
251+
};
252+
Ok(client)
253+
}
254+
}
255+
256+
async fn build_client(address: &str) -> anyhow::Result<Client> {
257+
tracing::log::debug!("Build new connection: {}", address);
258+
259+
let (client, connection) = tokio_postgres::connect(address, NoTls).await?;
260+
250261
tokio::spawn(async move {
251262
if let Err(e) = connection.await {
252263
tracing::warn!("Postgres connection error: {}", e);
253264
}
254265
});
266+
267+
Ok(client)
255268
}

crates/trigger/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ pub fn add_default_host_components<T: Default + Send + 'static>(
138138
builder.add_host_component(outbound_redis::OutboundRedis {
139139
connections: Arc::new(RwLock::new(HashMap::new())),
140140
})?;
141-
builder.add_host_component(outbound_pg::OutboundPg)?;
141+
builder.add_host_component(outbound_pg::OutboundPg {
142+
connections: HashMap::new(),
143+
})?;
142144
Ok(())
143145
}
144146

examples/rust-outbound-pg/src/lib.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ fn process(req: Request) -> Result<Response> {
2222
match req.uri().path() {
2323
"/read" => read(req),
2424
"/write" => write(req),
25+
"/pg_backend_pid" => pg_backend_pid(req),
2526
_ => Ok(http::Response::builder()
2627
.status(404)
2728
.body(Some("Not found".into()))?),
@@ -96,6 +97,27 @@ fn write(_req: Request) -> Result<Response> {
9697
.body(Some(response.into()))?)
9798
}
9899

100+
fn pg_backend_pid(_req: Request) -> Result<Response> {
101+
let address = std::env::var(DB_URL_ENV)?;
102+
let sql = "SELECT pg_backend_pid()";
103+
104+
let get_pid = || {
105+
let rowset = pg::query(&address, sql, &[])
106+
.map_err(|e| anyhow!("Error executing Postgres query: {:?}", e))?;
107+
108+
let row = &rowset.rows[0];
109+
as_int(&row[0])
110+
};
111+
112+
assert_eq!(get_pid()?, get_pid()?);
113+
114+
let response = format!("pg_backend_pid: {}\n", get_pid()?);
115+
116+
Ok(http::Response::builder()
117+
.status(200)
118+
.body(Some(response.into()))?)
119+
}
120+
99121
fn as_owned_string(value: &pg::DbValue) -> anyhow::Result<String> {
100122
match value {
101123
pg::DbValue::Str(s) => Ok(s.to_owned()),

0 commit comments

Comments
 (0)