Skip to content

Commit 3c5a913

Browse files
cctdanielReisen
andauthored
feat(hermes): add sse endpoint (#1425)
* add initial sse code * fix typo * add more error handling * fix formatting * revert import format * add error handling for nonexistent price feeds in the middle of sub * refactor * format * add comment * Update hermes/src/api/sse.rs Co-authored-by: Reisen <Reisen@users.noreply.github.com> * refactor * bump --------- Co-authored-by: Reisen <Reisen@users.noreply.github.com>
1 parent e1f9783 commit 3c5a913

File tree

6 files changed

+189
-6
lines changed

6 files changed

+189
-6
lines changed

hermes/Cargo.lock

Lines changed: 5 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

hermes/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "hermes"
3-
version = "0.5.3"
3+
version = "0.5.4"
44
description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
55
edition = "2021"
66

@@ -42,6 +42,7 @@ serde_wormhole = { git = "https://github.com/wormhole-foundation/wormhol
4242
sha3 = { version = "0.10.4" }
4343
strum = { version = "0.24.1", features = ["derive"] }
4444
tokio = { version = "1.26.0", features = ["full"] }
45+
tokio-stream = { version = "0.1.15", features = ["full"] }
4546
tonic = { version = "0.10.1", features = ["tls"] }
4647
tower-http = { version = "0.4.0", features = ["cors"] }
4748
tracing = { version = "0.1.37", features = ["log"] }

hermes/src/api.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use {
2323
mod doc_examples;
2424
mod metrics_middleware;
2525
mod rest;
26+
mod sse;
2627
pub mod types;
2728
mod ws;
2829

@@ -143,6 +144,10 @@ pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> {
143144
.route("/api/latest_price_feeds", get(rest::latest_price_feeds))
144145
.route("/api/latest_vaas", get(rest::latest_vaas))
145146
.route("/api/price_feed_ids", get(rest::price_feed_ids))
147+
.route(
148+
"/v2/updates/price/stream",
149+
get(sse::price_stream_sse_handler),
150+
)
146151
.route("/v2/updates/price/latest", get(rest::latest_price_updates))
147152
.route(
148153
"/v2/updates/price/:publish_time",

hermes/src/api/rest.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod price_feed_ids;
2121
mod ready;
2222
mod v2;
2323

24+
2425
pub use {
2526
get_price_feed::*,
2627
get_vaa::*,
@@ -38,6 +39,7 @@ pub use {
3839
},
3940
};
4041

42+
#[derive(Debug)]
4143
pub enum RestError {
4244
BenchmarkPriceNotUnique,
4345
UpdateDataNotFound,

hermes/src/api/sse.rs

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
use {
2+
crate::{
3+
aggregate::{
4+
AggregationEvent,
5+
RequestTime,
6+
},
7+
api::{
8+
rest::{
9+
verify_price_ids_exist,
10+
RestError,
11+
},
12+
types::{
13+
BinaryPriceUpdate,
14+
EncodingType,
15+
ParsedPriceUpdate,
16+
PriceIdInput,
17+
PriceUpdate,
18+
},
19+
ApiState,
20+
},
21+
},
22+
anyhow::Result,
23+
axum::{
24+
extract::State,
25+
response::sse::{
26+
Event,
27+
KeepAlive,
28+
Sse,
29+
},
30+
},
31+
futures::Stream,
32+
pyth_sdk::PriceIdentifier,
33+
serde::Deserialize,
34+
serde_qs::axum::QsQuery,
35+
std::convert::Infallible,
36+
tokio::sync::broadcast,
37+
tokio_stream::{
38+
wrappers::BroadcastStream,
39+
StreamExt as _,
40+
},
41+
utoipa::IntoParams,
42+
};
43+
44+
#[derive(Debug, Deserialize, IntoParams)]
45+
#[into_params(parameter_in = Query)]
46+
pub struct StreamPriceUpdatesQueryParams {
47+
/// Get the most recent price update for this set of price feed ids.
48+
///
49+
/// This parameter can be provided multiple times to retrieve multiple price updates,
50+
/// for example see the following query string:
51+
///
52+
/// ```
53+
/// ?ids[]=a12...&ids[]=b4c...
54+
/// ```
55+
#[param(rename = "ids[]")]
56+
#[param(example = "e62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43")]
57+
ids: Vec<PriceIdInput>,
58+
59+
/// If true, include the parsed price update in the `parsed` field of each returned feed.
60+
#[serde(default)]
61+
encoding: EncodingType,
62+
63+
/// If true, include the parsed price update in the `parsed` field of each returned feed.
64+
#[serde(default = "default_true")]
65+
parsed: bool,
66+
}
67+
68+
fn default_true() -> bool {
69+
true
70+
}
71+
72+
#[utoipa::path(
73+
get,
74+
path = "/v2/updates/price/stream",
75+
responses(
76+
(status = 200, description = "Price updates retrieved successfully", body = PriceUpdate),
77+
(status = 404, description = "Price ids not found", body = String)
78+
),
79+
params(StreamPriceUpdatesQueryParams)
80+
)]
81+
/// SSE route handler for streaming price updates.
82+
pub async fn price_stream_sse_handler(
83+
State(state): State<ApiState>,
84+
QsQuery(params): QsQuery<StreamPriceUpdatesQueryParams>,
85+
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RestError> {
86+
let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(Into::into).collect();
87+
88+
verify_price_ids_exist(&state, &price_ids).await?;
89+
90+
// Clone the update_tx receiver to listen for new price updates
91+
let update_rx: broadcast::Receiver<AggregationEvent> = state.update_tx.subscribe();
92+
93+
// Convert the broadcast receiver into a Stream
94+
let stream = BroadcastStream::new(update_rx);
95+
96+
let sse_stream = stream.then(move |message| {
97+
let state_clone = state.clone(); // Clone again to use inside the async block
98+
let price_ids_clone = price_ids.clone(); // Clone again for use inside the async block
99+
async move {
100+
match message {
101+
Ok(event) => {
102+
match handle_aggregation_event(
103+
event,
104+
state_clone,
105+
price_ids_clone,
106+
params.encoding,
107+
params.parsed,
108+
)
109+
.await
110+
{
111+
Ok(price_update) => Ok(Event::default().json_data(price_update).unwrap()),
112+
Err(e) => Ok(error_event(e)),
113+
}
114+
}
115+
Err(e) => Ok(error_event(e)),
116+
}
117+
}
118+
});
119+
120+
Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
121+
}
122+
123+
async fn handle_aggregation_event(
124+
event: AggregationEvent,
125+
state: ApiState,
126+
mut price_ids: Vec<PriceIdentifier>,
127+
encoding: EncodingType,
128+
parsed: bool,
129+
) -> Result<PriceUpdate> {
130+
// We check for available price feed ids to ensure that the price feed ids provided exists since price feeds can be removed.
131+
let available_price_feed_ids = crate::aggregate::get_price_feed_ids(&*state.state).await;
132+
133+
price_ids.retain(|price_feed_id| available_price_feed_ids.contains(price_feed_id));
134+
135+
let price_feeds_with_update_data = crate::aggregate::get_price_feeds_with_update_data(
136+
&*state.state,
137+
&price_ids,
138+
RequestTime::AtSlot(event.slot()),
139+
)
140+
.await?;
141+
let price_update_data = price_feeds_with_update_data.update_data;
142+
let encoded_data: Vec<String> = price_update_data
143+
.into_iter()
144+
.map(|data| encoding.encode_str(&data))
145+
.collect();
146+
let binary_price_update = BinaryPriceUpdate {
147+
encoding,
148+
data: encoded_data,
149+
};
150+
let parsed_price_updates: Option<Vec<ParsedPriceUpdate>> = if parsed {
151+
Some(
152+
price_feeds_with_update_data
153+
.price_feeds
154+
.into_iter()
155+
.map(|price_feed| price_feed.into())
156+
.collect(),
157+
)
158+
} else {
159+
None
160+
};
161+
162+
163+
Ok(PriceUpdate {
164+
binary: binary_price_update,
165+
parsed: parsed_price_updates,
166+
})
167+
}
168+
169+
fn error_event<E: std::fmt::Debug>(e: E) -> Event {
170+
Event::default()
171+
.event("error")
172+
.data(format!("Error receiving update: {:?}", e))
173+
}

hermes/src/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ mod state;
2828

2929
lazy_static! {
3030
/// A static exit flag to indicate to running threads that we're shutting down. This is used to
31-
/// gracefully shutdown the application.
31+
/// gracefully shut down the application.
3232
///
3333
/// We make this global based on the fact the:
3434
/// - The `Sender` side does not rely on any async runtime.
3535
/// - Exit logic doesn't really require carefully threading this value through the app.
3636
/// - The `Receiver` side of a watch channel performs the detection based on if the change
3737
/// happened after the subscribe, so it means all listeners should always be notified
38-
/// currectly.
38+
/// correctly.
3939
pub static ref EXIT: watch::Sender<bool> = watch::channel(false).0;
4040
}
4141

0 commit comments

Comments
 (0)