Skip to content

Commit ba4018f

Browse files
committed
refactor(agent): use exit flag everywhere
1 parent 1fda7fa commit ba4018f

File tree

6 files changed

+72
-111
lines changed

6 files changed

+72
-111
lines changed

src/agent.rs

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,6 @@ Metrics Server:
6161
Note that there is an Oracle and Exporter for each network, but only one Local Store and Global Store.
6262
6363
################################################################################################################################## */
64-
65-
pub mod legacy_schedule;
66-
pub mod market_schedule;
67-
pub mod metrics;
68-
pub mod pythd;
69-
pub mod solana;
70-
pub mod state;
71-
pub mod store;
7264
use {
7365
self::{
7466
config::Config,
@@ -78,11 +70,33 @@ use {
7870
},
7971
anyhow::Result,
8072
futures_util::future::join_all,
73+
lazy_static::lazy_static,
8174
slog::Logger,
8275
std::sync::Arc,
83-
tokio::sync::broadcast,
76+
tokio::sync::watch,
8477
};
8578

79+
pub mod legacy_schedule;
80+
pub mod market_schedule;
81+
pub mod metrics;
82+
pub mod pythd;
83+
pub mod solana;
84+
pub mod state;
85+
pub mod store;
86+
87+
lazy_static! {
88+
/// A static exit flag to indicate to running threads that we're shutting down. This is used to
89+
/// gracefully shut down the application.
90+
///
91+
/// We make this global based on the fact the:
92+
/// - The `Sender` side does not rely on any async runtime.
93+
/// - Exit logic doesn't really require carefully threading this value through the app.
94+
/// - The `Receiver` side of a watch channel performs the detection based on if the change
95+
/// happened after the subscribe, so it means all listeners should always be notified
96+
/// correctly.
97+
pub static ref EXIT: watch::Sender<bool> = watch::channel(false).0;
98+
}
99+
86100
pub struct Agent {
87101
config: Config,
88102
}
@@ -109,10 +123,6 @@ impl Agent {
109123
// job handles
110124
let mut jhs = vec![];
111125

112-
// Create the channels
113-
// TODO: make all components listen to shutdown signal
114-
let (shutdown_tx, _) = broadcast::channel(self.config.channel_capacities.shutdown);
115-
116126
// Create the Pythd Adapter.
117127
let adapter =
118128
Arc::new(state::State::new(self.config.pythd_adapter.clone(), logger.clone()).await);
@@ -136,17 +146,13 @@ impl Agent {
136146
}
137147

138148
// Create the Notifier task for the Pythd RPC.
139-
jhs.push(tokio::spawn(notifier(
140-
adapter.clone(),
141-
shutdown_tx.subscribe(),
142-
)));
149+
jhs.push(tokio::spawn(notifier(adapter.clone())));
143150

144151
// Spawn the Pythd API Server
145152
jhs.push(tokio::spawn(rpc::run(
146153
self.config.pythd_api_server.clone(),
147154
logger.clone(),
148155
adapter.clone(),
149-
shutdown_tx.subscribe(),
150156
)));
151157

152158
// Spawn the metrics server

src/agent/metrics.rs

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,22 @@
11
use {
2-
super::state::{
3-
local::PriceInfo,
4-
State,
5-
},
6-
crate::agent::{
7-
solana::oracle::PriceEntry,
8-
store::PriceIdentifier,
9-
},
2+
super::state::{local::PriceInfo, State},
3+
crate::agent::{solana::oracle::PriceEntry, store::PriceIdentifier},
104
lazy_static::lazy_static,
115
prometheus_client::{
12-
encoding::{
13-
text::encode,
14-
EncodeLabelSet,
15-
},
16-
metrics::{
17-
counter::Counter,
18-
family::Family,
19-
gauge::Gauge,
20-
},
6+
encoding::{text::encode, EncodeLabelSet},
7+
metrics::{counter::Counter, family::Family, gauge::Gauge},
218
registry::Registry,
229
},
2310
serde::Deserialize,
2411
slog::Logger,
2512
solana_sdk::pubkey::Pubkey,
2613
std::{
2714
net::SocketAddr,
28-
sync::{
29-
atomic::AtomicU64,
30-
Arc,
31-
},
15+
sync::{atomic::AtomicU64, Arc},
3216
time::Instant,
3317
},
3418
tokio::sync::Mutex,
35-
warp::{
36-
hyper::StatusCode,
37-
reply,
38-
Filter,
39-
Rejection,
40-
Reply,
41-
},
19+
warp::{hyper::StatusCode, reply, Filter, Rejection, Reply},
4220
};
4321

4422
pub fn default_bind_address() -> SocketAddr {
@@ -68,8 +46,8 @@ lazy_static! {
6846
/// metrics.
6947
pub struct MetricsServer {
7048
pub start_time: Instant,
71-
pub logger: Logger,
72-
pub adapter: Arc<State>,
49+
pub logger: Logger,
50+
pub adapter: Arc<State>,
7351
}
7452

7553
impl MetricsServer {
@@ -105,7 +83,11 @@ impl MetricsServer {
10583
}
10684
});
10785

108-
warp::serve(metrics_route).bind(addr).await;
86+
let (_, serve) = warp::serve(metrics_route).bind_with_graceful_shutdown(addr, async {
87+
let _ = crate::agent::EXIT.subscribe().changed().await;
88+
});
89+
90+
serve.await
10991
}
11092
}
11193

@@ -169,12 +151,12 @@ pub struct PriceGlobalMetrics {
169151

170152
/// f64 is used to get u64 support. Official docs:
171153
/// https://docs.rs/prometheus-client/latest/prometheus_client/metrics/gauge/struct.Gauge.html#using-atomicu64-as-storage-and-f64-on-the-interface
172-
conf: Family<PriceGlobalLabels, Gauge<f64, AtomicU64>>,
154+
conf: Family<PriceGlobalLabels, Gauge<f64, AtomicU64>>,
173155
timestamp: Family<PriceGlobalLabels, Gauge>,
174156

175157
/// Note: the exponent is not applied to this metric
176-
prev_price: Family<PriceGlobalLabels, Gauge>,
177-
prev_conf: Family<PriceGlobalLabels, Gauge<f64, AtomicU64>>,
158+
prev_price: Family<PriceGlobalLabels, Gauge>,
159+
prev_conf: Family<PriceGlobalLabels, Gauge<f64, AtomicU64>>,
178160
prev_timestamp: Family<PriceGlobalLabels, Gauge>,
179161

180162
/// How many times this Price was updated in the global store
@@ -317,10 +299,10 @@ pub struct PriceLocalLabels {
317299
/// Metrics exposed to Prometheus by the local store for each price
318300
#[derive(Default)]
319301
pub struct PriceLocalMetrics {
320-
price: Family<PriceLocalLabels, Gauge>,
302+
price: Family<PriceLocalLabels, Gauge>,
321303
/// f64 is used to get u64 support. Official docs:
322304
/// https://docs.rs/prometheus-client/latest/prometheus_client/metrics/gauge/struct.Gauge.html#using-atomicu64-as-storage-and-f64-on-the-interface
323-
conf: Family<PriceLocalLabels, Gauge<f64, AtomicU64>>,
305+
conf: Family<PriceLocalLabels, Gauge<f64, AtomicU64>>,
324306
timestamp: Family<PriceLocalLabels, Gauge>,
325307

326308
/// How many times this price was updated in the local store

src/agent/pythd/api/rpc.rs

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,7 @@ use {
5050
net::SocketAddr,
5151
sync::Arc,
5252
},
53-
tokio::sync::{
54-
broadcast,
55-
mpsc,
56-
},
53+
tokio::sync::mpsc,
5754
warp::{
5855
ws::{
5956
Message,
@@ -430,29 +427,20 @@ impl Default for Config {
430427
}
431428
}
432429

433-
pub async fn run<S>(
434-
config: Config,
435-
logger: Logger,
436-
adapter: Arc<S>,
437-
shutdown_rx: broadcast::Receiver<()>,
438-
) where
430+
pub async fn run<S>(config: Config, logger: Logger, adapter: Arc<S>)
431+
where
439432
S: state::StateApi,
440433
S: Send,
441434
S: Sync,
442435
S: 'static,
443436
{
444-
if let Err(err) = serve(config, &logger, adapter, shutdown_rx).await {
437+
if let Err(err) = serve(config, &logger, adapter).await {
445438
error!(logger, "{}", err);
446439
debug!(logger, "error context"; "context" => format!("{:?}", err));
447440
}
448441
}
449442

450-
async fn serve<S>(
451-
config: Config,
452-
logger: &Logger,
453-
adapter: Arc<S>,
454-
mut shutdown_rx: broadcast::Receiver<()>,
455-
) -> Result<()>
443+
async fn serve<S>(config: Config, logger: &Logger, adapter: Arc<S>) -> Result<()>
456444
where
457445
S: state::StateApi,
458446
S: Send,
@@ -490,8 +478,8 @@ where
490478

491479
let (_, serve) = warp::serve(index).bind_with_graceful_shutdown(
492480
config.listen_address.as_str().parse::<SocketAddr>()?,
493-
async move {
494-
let _ = shutdown_rx.recv().await;
481+
async {
482+
let _ = crate::agent::EXIT.subscribe().changed().await;
495483
},
496484
);
497485

src/agent/state.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ mod tests {
185185
let (shutdown_tx, _) = broadcast::channel(1);
186186

187187
// Spawn Price Notifier
188-
let jh = tokio::spawn(notifier(adapter.clone(), shutdown_tx.subscribe()));
188+
let jh = tokio::spawn(notifier(adapter.clone()));
189189

190190
TestAdapter {
191191
adapter,

src/agent/state/api.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,7 @@ use {
4545
PriceStatus,
4646
},
4747
std::sync::Arc,
48-
tokio::sync::{
49-
broadcast,
50-
mpsc,
51-
},
48+
tokio::sync::mpsc,
5249
};
5350

5451
// TODO: implement Display on PriceStatus and then just call PriceStatus::to_string
@@ -172,12 +169,13 @@ pub trait StateApi {
172169
) -> Result<()>;
173170
}
174171

175-
pub async fn notifier(adapter: Arc<State>, mut shutdown_rx: broadcast::Receiver<()>) {
172+
pub async fn notifier(adapter: Arc<State>) {
176173
let mut interval = tokio::time::interval(adapter.notify_price_sched_interval_duration);
174+
let mut exit = crate::agent::EXIT.subscribe();
177175
loop {
178176
adapter.drop_closed_subscriptions().await;
179177
tokio::select! {
180-
_ = shutdown_rx.recv() => {
178+
_ = exit.changed() => {
181179
info!(adapter.logger, "shutdown signal received");
182180
return;
183181
}

src/agent/state/keypairs.rs

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,17 @@
55
use {
66
super::State,
77
crate::agent::solana::network::Network,
8-
anyhow::{
9-
Context,
10-
Result,
11-
},
8+
anyhow::{Context, Result},
129
serde::Deserialize,
1310
slog::Logger,
1411
solana_client::nonblocking::rpc_client::RpcClient,
15-
solana_sdk::{
16-
commitment_config::CommitmentConfig,
17-
signature::Keypair,
18-
signer::Signer,
19-
},
20-
std::{
21-
net::SocketAddr,
22-
sync::Arc,
23-
},
24-
tokio::{
25-
sync::RwLock,
26-
task::JoinHandle,
27-
},
12+
solana_sdk::{commitment_config::CommitmentConfig, signature::Keypair, signer::Signer},
13+
std::{net::SocketAddr, sync::Arc},
14+
tokio::{sync::RwLock, task::JoinHandle},
2815
warp::{
2916
hyper::StatusCode,
30-
reply::{
31-
self,
32-
WithStatus,
33-
},
34-
Filter,
35-
Rejection,
17+
reply::{self, WithStatus},
18+
Filter, Rejection,
3619
},
3720
};
3821

@@ -49,24 +32,24 @@ pub fn default_bind_address() -> SocketAddr {
4932
#[derive(Clone, Debug, Deserialize)]
5033
#[serde(default)]
5134
pub struct Config {
52-
primary_min_keypair_balance_sol: u64,
35+
primary_min_keypair_balance_sol: u64,
5336
secondary_min_keypair_balance_sol: u64,
54-
bind_address: SocketAddr,
37+
bind_address: SocketAddr,
5538
}
5639

5740
impl Default for Config {
5841
fn default() -> Self {
5942
Self {
60-
primary_min_keypair_balance_sol: default_min_keypair_balance_sol(),
43+
primary_min_keypair_balance_sol: default_min_keypair_balance_sol(),
6144
secondary_min_keypair_balance_sol: default_min_keypair_balance_sol(),
62-
bind_address: default_bind_address(),
45+
bind_address: default_bind_address(),
6346
}
6447
}
6548
}
6649

6750
#[derive(Default)]
6851
pub struct KeypairState {
69-
primary_current_keypair: RwLock<Option<Keypair>>,
52+
primary_current_keypair: RwLock<Option<Keypair>>,
7053
secondary_current_keypair: RwLock<Option<Keypair>>,
7154
}
7255

@@ -193,9 +176,13 @@ where
193176
}
194177
});
195178

196-
let http_api_jh = tokio::spawn(
197-
warp::serve(primary_upload_route.or(secondary_upload_route)).bind(config.bind_address),
198-
);
179+
let http_api_jh = {
180+
let (_, serve) = warp::serve(primary_upload_route.or(secondary_upload_route))
181+
.bind_with_graceful_shutdown(config.bind_address, async {
182+
let _ = crate::agent::EXIT.subscribe().changed().await;
183+
});
184+
tokio::spawn(serve)
185+
};
199186

200187
// WARNING: All jobs spawned here must report their join handles in this vec
201188
vec![http_api_jh]

0 commit comments

Comments
 (0)