diff --git a/Cargo.lock b/Cargo.lock index cf4426d..b8c7da4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2920,6 +2920,7 @@ name = "pythnet-watcher" version = "0.1.0" dependencies = [ "anyhow", + "base64 0.22.1", "borsh 0.9.3", "clap", "hex", diff --git a/Cargo.toml b/Cargo.toml index eb03763..284e728 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,4 @@ wormhole-vaas-serde = "0.1.0" [dev-dependencies] serde_json = "1.0.140" +base64 = "0.22.1" diff --git a/src/main.rs b/src/main.rs index c567312..95e9793 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ use { pubsub_client::PubsubClientError, rpc_config::{RpcAccountInfoConfig, RpcProgramAccountsConfig}, rpc_filter::{Memcmp, RpcFilterType}, + rpc_response::{Response, RpcKeyedAccount}, }, solana_sdk::pubkey::Pubkey, std::{fs, str::FromStr, time::Duration}, @@ -40,6 +41,68 @@ fn find_message_pda(wormhole_pid: &Pubkey, slot: u64) -> Pubkey { .0 } +const FAILED_TO_DECODE: &str = "Failed to decode account data"; +const INVALID_UNRELIABLE_DATA_FORMAT: &str = "Invalid unreliable data format"; +const INVALID_PDA_MESSAGE: &str = "Invalid PDA message"; +const INVALID_EMITTER_CHAIN: &str = "Invalid emitter chain"; +const INVALID_ACCUMULATOR_ADDRESS: &str = "Invalid accumulator address"; + +fn decode_and_verify_update( + wormhole_pid: &Pubkey, + accumulator_address: &Pubkey, + update: Response, +) -> anyhow::Result { + if find_message_pda(wormhole_pid, update.context.slot).to_string() != update.value.pubkey { + return Err(anyhow::anyhow!(INVALID_PDA_MESSAGE)); + } + let data = update.value.account.data.decode().ok_or_else(|| { + tracing::error!( + data = ?update.value.account.data, + "Failed to decode account data", + ); + anyhow::anyhow!(FAILED_TO_DECODE) + })?; + let unreliable_data: PostedMessageUnreliableData = + BorshDeserialize::deserialize(&mut data.as_slice()).map_err(|e| { + tracing::error!( + data = ?data, + error = ?e, + "Failed to decode unreliable data", + ); + anyhow::anyhow!(format!("{}: {}", INVALID_UNRELIABLE_DATA_FORMAT, e)) + })?; + + if Chain::Pythnet != unreliable_data.emitter_chain.into() { + tracing::error!( + emitter_chain = unreliable_data.emitter_chain, + "Invalid emitter chain" + ); + return Err(anyhow::anyhow!(INVALID_EMITTER_CHAIN)); + } + + if accumulator_address != &Pubkey::from(unreliable_data.emitter_address) { + tracing::error!( + emitter_address = ?unreliable_data.emitter_address, + "Invalid accumulator address" + ); + return Err(anyhow::anyhow!(INVALID_ACCUMULATOR_ADDRESS)); + } + + Ok(unreliable_data) +} + +fn message_data_to_body(unreliable_data: &PostedMessageUnreliableData) -> Body<&RawMessage> { + Body { + timestamp: unreliable_data.submission_time, + nonce: unreliable_data.nonce, + emitter_chain: unreliable_data.emitter_chain.into(), + emitter_address: Address(unreliable_data.emitter_address), + sequence: unreliable_data.sequence, + consistency_level: unreliable_data.consistency_level, + payload: RawMessage::new(unreliable_data.payload.as_slice()), + } +} + async fn run_listener(input: RunListenerInput) -> Result<(), PubsubClientError> { let client = PubsubClient::new(input.ws_url.as_str()).await?; let (mut stream, unsubscribe) = client @@ -63,49 +126,17 @@ async fn run_listener(input: RunListenerInput) -> Result<(), PubsubClientError> .await?; while let Some(update) = stream.next().await { - if find_message_pda(&input.wormhole_pid, update.context.slot).to_string() - != update.value.pubkey - { - continue; // Skip updates that are not for the expected PDA - } - - let unreliable_data: PostedMessageUnreliableData = { - let data = match update.value.account.data.decode() { - Some(data) => data, - None => { - tracing::error!("Failed to decode account data"); - continue; - } - }; - - match BorshDeserialize::deserialize(&mut data.as_slice()) { + let unreliable_data = + match decode_and_verify_update(&input.wormhole_pid, &input.accumulator_address, update) + { Ok(data) => data, - Err(e) => { - tracing::error!(error = ?e, "Invalid unreliable data format"); - continue; - } - } - }; - - if Chain::Pythnet != unreliable_data.emitter_chain.into() { - continue; - } - if input.accumulator_address != Pubkey::from(unreliable_data.emitter_address) { - continue; - } + Err(_) => continue, + }; tokio::spawn({ let api_client = input.api_client.clone(); async move { - let body = Body { - timestamp: unreliable_data.submission_time, - nonce: unreliable_data.nonce, - emitter_chain: unreliable_data.emitter_chain.into(), - emitter_address: Address(unreliable_data.emitter_address), - sequence: unreliable_data.sequence, - consistency_level: unreliable_data.consistency_level, - payload: RawMessage::new(unreliable_data.payload.as_slice()), - }; + let body = message_data_to_body(&unreliable_data); match Observation::try_new(body.clone(), input.secret_key) { Ok(observation) => { if let Err(e) = api_client.post_observation(observation).await { @@ -171,3 +202,192 @@ async fn main() { } } } + +#[cfg(test)] +mod tests { + use super::*; + + use base64::Engine; + use borsh::BorshSerialize; + use solana_account_decoder::{UiAccount, UiAccountData}; + + use crate::posted_message::MessageData; + + fn get_wormhole_pid() -> Pubkey { + Pubkey::from_str("H3fxXJ86ADW2PNuDDmZJg6mzTtPxkYCpNuQUTgmJ7AjU").unwrap() + } + + fn get_accumulator_address() -> Pubkey { + Pubkey::from_str("G9LV2mp9ua1znRAfYwZz5cPiJMAbo1T6mbjdQsDZuMJg").unwrap() + } + + fn get_payload() -> Vec { + vec![ + 65, 85, 87, 86, 0, 0, 0, 0, 0, 13, 74, 15, 90, 0, 0, 39, 16, 172, 145, 156, 108, 253, + 178, 4, 138, 51, 74, 110, 116, 101, 139, 121, 254, 152, 165, 24, 190, + ] + } + + fn get_unreliable_data() -> PostedMessageUnreliableData { + PostedMessageUnreliableData { + message: MessageData { + submission_time: 1749732585, + nonce: 0, + emitter_chain: Chain::Pythnet.into(), + emitter_address: [ + 225, 1, 250, 237, 172, 88, 81, 227, 43, 155, 35, 181, 249, 65, 26, 140, 43, + 172, 74, 174, 62, 212, 221, 123, 129, 29, 209, 167, 46, 164, 170, 113, + ], + sequence: 138184361, + consistency_level: 1, + payload: get_payload(), + vaa_version: 1, + vaa_time: 0, + vaa_signature_account: [0; 32], + }, + } + } + + fn get_update(unreliable_data: PostedMessageUnreliableData) -> Response { + let message = unreliable_data.try_to_vec().unwrap(); + let message = base64::engine::general_purpose::STANDARD.encode(&message); + Response { + context: solana_client::rpc_response::RpcResponseContext { + slot: 123456, + api_version: None, + }, + value: RpcKeyedAccount { + pubkey: find_message_pda(&get_wormhole_pid(), 123456).to_string(), + account: UiAccount { + lamports: 0, + data: UiAccountData::Binary(message, UiAccountEncoding::Base64), + owner: get_accumulator_address().to_string(), + executable: false, + rent_epoch: 0, + space: None, + }, + }, + } + } + + #[test] + fn test_find_message_pda() { + assert_eq!( + find_message_pda(&get_wormhole_pid(), 123456).to_string(), + "Ed9gRoBySmUjSVFxovuhTk6AcFkv9uE8EovvshtHWLNT" + ); + } + + #[test] + fn test_get_body() { + let unreliable_data = get_unreliable_data(); + let body = message_data_to_body(&unreliable_data); + assert_eq!(body.timestamp, unreliable_data.submission_time); + assert_eq!(body.nonce, unreliable_data.nonce); + assert_eq!(body.emitter_chain, Chain::Pythnet); + assert_eq!( + body.emitter_address, + Address(unreliable_data.emitter_address) + ); + assert_eq!(body.sequence, unreliable_data.sequence); + assert_eq!(body.payload, RawMessage::new(get_payload().as_slice())); + } + + #[test] + fn test_decode_and_verify_update() { + let expected_unreliable_data = get_unreliable_data(); + let update = get_update(expected_unreliable_data.clone()); + let result = + decode_and_verify_update(&get_wormhole_pid(), &get_accumulator_address(), update); + + assert!(result.is_ok()); + let unreliable_data = result.unwrap(); + + assert_eq!( + expected_unreliable_data.consistency_level, + unreliable_data.consistency_level + ); + assert_eq!( + expected_unreliable_data.emitter_chain, + unreliable_data.emitter_chain + ); + assert_eq!( + expected_unreliable_data.emitter_address, + unreliable_data.emitter_address + ); + assert_eq!(expected_unreliable_data.sequence, unreliable_data.sequence); + assert_eq!( + expected_unreliable_data.submission_time, + unreliable_data.submission_time + ); + assert_eq!(expected_unreliable_data.nonce, unreliable_data.nonce); + assert_eq!(expected_unreliable_data.payload, unreliable_data.payload); + assert_eq!( + expected_unreliable_data.vaa_version, + unreliable_data.vaa_version + ); + assert_eq!(expected_unreliable_data.vaa_time, unreliable_data.vaa_time); + assert_eq!( + expected_unreliable_data.vaa_signature_account, + unreliable_data.vaa_signature_account + ); + } + + #[test] + fn test_decode_and_verify_update_invalid_pda() { + let mut update = get_update(get_unreliable_data()); + update.context.slot += 1; + let result = + decode_and_verify_update(&get_wormhole_pid(), &get_accumulator_address(), update); + assert_eq!(result.unwrap_err().to_string(), INVALID_PDA_MESSAGE); + } + + #[test] + fn test_decode_and_verify_update_failed_decode() { + let mut update = get_update(get_unreliable_data()); + update.value.account.data = + UiAccountData::Binary("invalid_base64".to_string(), UiAccountEncoding::Base64); + let result = + decode_and_verify_update(&get_wormhole_pid(), &get_accumulator_address(), update); + assert_eq!(result.unwrap_err().to_string(), FAILED_TO_DECODE); + } + + #[test] + fn test_decode_and_verify_update_invalid_unreliable_data() { + let mut update = get_update(get_unreliable_data()); + let message = base64::engine::general_purpose::STANDARD.encode(vec![4, 1, 2, 3, 4]); + update.value.account.data = UiAccountData::Binary(message, UiAccountEncoding::Base64); + let result = + decode_and_verify_update(&get_wormhole_pid(), &get_accumulator_address(), update); + let error_message = format!( + "{}: {}", + INVALID_UNRELIABLE_DATA_FORMAT, + "Magic mismatch. Expected [109, 115, 117] but got [4, 1, 2]" + ); + assert_eq!(result.unwrap_err().to_string(), error_message); + } + + #[test] + fn test_decode_and_verify_update_invalid_emitter_chain() { + let mut unreliable_data = get_unreliable_data(); + unreliable_data.emitter_chain = Chain::Solana.into(); + let result = decode_and_verify_update( + &get_wormhole_pid(), + &get_accumulator_address(), + get_update(unreliable_data), + ); + assert_eq!(result.unwrap_err().to_string(), INVALID_EMITTER_CHAIN); + } + + #[test] + fn test_decode_and_verify_update_invalid_emitter_address() { + let mut unreliable_data = get_unreliable_data(); + unreliable_data.emitter_address = Pubkey::new_unique().to_bytes(); + let result = decode_and_verify_update( + &get_wormhole_pid(), + &get_accumulator_address(), + get_update(unreliable_data), + ); + assert_eq!(result.unwrap_err().to_string(), INVALID_ACCUMULATOR_ADDRESS); + } +} diff --git a/src/posted_message.rs b/src/posted_message.rs index f7e1cec..80777f5 100644 --- a/src/posted_message.rs +++ b/src/posted_message.rs @@ -1,3 +1,12 @@ +//! This module defines the `PostedMessage` structure used to parse and verify messages +//! posted by the Wormhole protocol. +//! +//! ⚠️ Note: This is mostly a copy-paste from the Wormhole reference implementation. +//! If you forget how it works or need updates, refer to the official source: +//! https://github.com/wormhole-foundation/wormhole/blob/main/solana/bridge/program/src/accounts/posted_message.rs# +//! +//! Keep in sync if the upstream changes! + use { borsh::{BorshDeserialize, BorshSerialize}, serde::{Deserialize, Serialize},