From a9c0c754e18761e10acfd0622a5b4ddbb1c46f17 Mon Sep 17 00:00:00 2001 From: Danial Mehrjerdi Date: Thu, 12 Jun 2025 16:05:00 +0200 Subject: [PATCH 1/2] Add tests for main file --- Cargo.lock | 1 + Cargo.toml | 1 + src/main.rs | 299 ++++++++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 266 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cf4426d..b8c7da4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2920,6 +2920,7 @@ name = "pythnet-watcher" version = "0.1.0" dependencies = [ "anyhow", + "base64 0.22.1", "borsh 0.9.3", "clap", "hex", diff --git a/Cargo.toml b/Cargo.toml index eb03763..284e728 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,4 @@ wormhole-vaas-serde = "0.1.0" [dev-dependencies] serde_json = "1.0.140" +base64 = "0.22.1" diff --git a/src/main.rs b/src/main.rs index c567312..be4af0c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ use { pubsub_client::PubsubClientError, rpc_config::{RpcAccountInfoConfig, RpcProgramAccountsConfig}, rpc_filter::{Memcmp, RpcFilterType}, + rpc_response::{Response, RpcKeyedAccount}, }, solana_sdk::pubkey::Pubkey, std::{fs, str::FromStr, time::Duration}, @@ -40,6 +41,62 @@ fn find_message_pda(wormhole_pid: &Pubkey, slot: u64) -> Pubkey { .0 } +const FAILED_TO_DECODE: &str = "Failed to decode account data"; +const INVALID_UNRELIABLE_DATA_FORMAT: &str = "Invalid unreliable data format"; + +#[derive(Debug)] +enum VerifyUpdateError { + InvalidMessagePDA, + InvalidEmitterChain, + InvalidAccumulatorAddress, + #[allow(dead_code)] + DecodingError(String), +} + +fn decode_and_verify_update( + wormhole_pid: &Pubkey, + accumulator_address: &Pubkey, + update: Response, +) -> Result { + if find_message_pda(wormhole_pid, update.context.slot).to_string() != update.value.pubkey { + return Err(VerifyUpdateError::InvalidMessagePDA); + } + let data = update + .value + .account + .data + .decode() + .ok_or(VerifyUpdateError::DecodingError( + FAILED_TO_DECODE.to_string(), + ))?; + let unreliable_data: PostedMessageUnreliableData = + BorshDeserialize::deserialize(&mut data.as_slice()).map_err(|e| { + VerifyUpdateError::DecodingError(format!("{}: {}", INVALID_UNRELIABLE_DATA_FORMAT, e)) + })?; + + if Chain::Pythnet != unreliable_data.emitter_chain.into() { + return Err(VerifyUpdateError::InvalidEmitterChain); + } + + if accumulator_address != &Pubkey::from(unreliable_data.emitter_address) { + return Err(VerifyUpdateError::InvalidAccumulatorAddress); + } + + Ok(unreliable_data) +} + +fn new_body(unreliable_data: &PostedMessageUnreliableData) -> Body<&RawMessage> { + Body { + timestamp: unreliable_data.submission_time, + nonce: unreliable_data.nonce, + emitter_chain: unreliable_data.emitter_chain.into(), + emitter_address: Address(unreliable_data.emitter_address), + sequence: unreliable_data.sequence, + consistency_level: unreliable_data.consistency_level, + payload: RawMessage::new(unreliable_data.payload.as_slice()), + } +} + async fn run_listener(input: RunListenerInput) -> Result<(), PubsubClientError> { let client = PubsubClient::new(input.ws_url.as_str()).await?; let (mut stream, unsubscribe) = client @@ -63,49 +120,22 @@ async fn run_listener(input: RunListenerInput) -> Result<(), PubsubClientError> .await?; while let Some(update) = stream.next().await { - if find_message_pda(&input.wormhole_pid, update.context.slot).to_string() - != update.value.pubkey - { - continue; // Skip updates that are not for the expected PDA - } - - let unreliable_data: PostedMessageUnreliableData = { - let data = match update.value.account.data.decode() { - Some(data) => data, - None => { - tracing::error!("Failed to decode account data"); - continue; - } - }; - - match BorshDeserialize::deserialize(&mut data.as_slice()) { + let unreliable_data = + match decode_and_verify_update(&input.wormhole_pid, &input.accumulator_address, update) + { Ok(data) => data, Err(e) => { - tracing::error!(error = ?e, "Invalid unreliable data format"); + if !matches!(e, VerifyUpdateError::InvalidMessagePDA) { + tracing::error!(error = ?e, "Received an invalid update"); + } continue; } - } - }; - - if Chain::Pythnet != unreliable_data.emitter_chain.into() { - continue; - } - if input.accumulator_address != Pubkey::from(unreliable_data.emitter_address) { - continue; - } + }; tokio::spawn({ let api_client = input.api_client.clone(); async move { - let body = Body { - timestamp: unreliable_data.submission_time, - nonce: unreliable_data.nonce, - emitter_chain: unreliable_data.emitter_chain.into(), - emitter_address: Address(unreliable_data.emitter_address), - sequence: unreliable_data.sequence, - consistency_level: unreliable_data.consistency_level, - payload: RawMessage::new(unreliable_data.payload.as_slice()), - }; + let body = new_body(&unreliable_data); match Observation::try_new(body.clone(), input.secret_key) { Ok(observation) => { if let Err(e) = api_client.post_observation(observation).await { @@ -171,3 +201,202 @@ async fn main() { } } } + +#[cfg(test)] +mod tests { + use base64::Engine; + use borsh::BorshSerialize; + use solana_account_decoder::{UiAccount, UiAccountData}; + + use super::*; + use crate::posted_message::MessageData; + + fn get_wormhole_pid() -> Pubkey { + Pubkey::from_str("H3fxXJ86ADW2PNuDDmZJg6mzTtPxkYCpNuQUTgmJ7AjU").unwrap() + } + + fn get_accumulator_address() -> Pubkey { + Pubkey::from_str("G9LV2mp9ua1znRAfYwZz5cPiJMAbo1T6mbjdQsDZuMJg").unwrap() + } + + fn get_payload() -> Vec { + vec![ + 65, 85, 87, 86, 0, 0, 0, 0, 0, 13, 74, 15, 90, 0, 0, 39, 16, 172, 145, 156, 108, 253, + 178, 4, 138, 51, 74, 110, 116, 101, 139, 121, 254, 152, 165, 24, 190, + ] + } + + fn get_unreliable_data() -> PostedMessageUnreliableData { + PostedMessageUnreliableData { + message: MessageData { + submission_time: 1749732585, + nonce: 0, + emitter_chain: Chain::Pythnet.into(), + emitter_address: [ + 225, 1, 250, 237, 172, 88, 81, 227, 43, 155, 35, 181, 249, 65, 26, 140, 43, + 172, 74, 174, 62, 212, 221, 123, 129, 29, 209, 167, 46, 164, 170, 113, + ], + sequence: 138184361, + consistency_level: 1, + payload: get_payload(), + vaa_version: 1, + vaa_time: 0, + vaa_signature_account: [0; 32], + }, + } + } + + fn get_update(unreliable_data: PostedMessageUnreliableData) -> Response { + let message = unreliable_data.try_to_vec().unwrap(); + let message = base64::engine::general_purpose::STANDARD.encode(&message); + Response { + context: solana_client::rpc_response::RpcResponseContext { + slot: 123456, + api_version: None, + }, + value: RpcKeyedAccount { + pubkey: find_message_pda(&get_wormhole_pid(), 123456).to_string(), + account: UiAccount { + lamports: 0, + data: UiAccountData::Binary(message, UiAccountEncoding::Base64), + owner: get_accumulator_address().to_string(), + executable: false, + rent_epoch: 0, + space: None, + }, + }, + } + } + + #[test] + fn test_find_message_pda() { + assert_eq!( + find_message_pda(&get_wormhole_pid(), 123456).to_string(), + "Ed9gRoBySmUjSVFxovuhTk6AcFkv9uE8EovvshtHWLNT" + ); + } + + #[test] + fn test_get_body() { + let unreliable_data = get_unreliable_data(); + let body = new_body(&unreliable_data); + assert_eq!(body.timestamp, unreliable_data.submission_time); + assert_eq!(body.nonce, unreliable_data.nonce); + assert_eq!(body.emitter_chain, Chain::Pythnet); + assert_eq!( + body.emitter_address, + Address(unreliable_data.emitter_address) + ); + assert_eq!(body.sequence, unreliable_data.sequence); + assert_eq!(body.payload, RawMessage::new(get_payload().as_slice())); + } + + #[test] + fn test_decode_and_verify_update() { + let expected_unreliable_data = get_unreliable_data(); + let update = get_update(expected_unreliable_data.clone()); + let result = + decode_and_verify_update(&get_wormhole_pid(), &get_accumulator_address(), update); + + assert!(result.is_ok()); + let unreliable_data = result.unwrap(); + + assert_eq!( + expected_unreliable_data.consistency_level, + unreliable_data.consistency_level + ); + assert_eq!( + expected_unreliable_data.emitter_chain, + unreliable_data.emitter_chain + ); + assert_eq!( + expected_unreliable_data.emitter_address, + unreliable_data.emitter_address + ); + assert_eq!(expected_unreliable_data.sequence, unreliable_data.sequence); + assert_eq!( + expected_unreliable_data.submission_time, + unreliable_data.submission_time + ); + assert_eq!(expected_unreliable_data.nonce, unreliable_data.nonce); + assert_eq!(expected_unreliable_data.payload, unreliable_data.payload); + assert_eq!( + expected_unreliable_data.vaa_version, + unreliable_data.vaa_version + ); + assert_eq!(expected_unreliable_data.vaa_time, unreliable_data.vaa_time); + assert_eq!( + expected_unreliable_data.vaa_signature_account, + unreliable_data.vaa_signature_account + ); + } + + #[test] + fn test_decode_and_verify_update_invalid_pda() { + let mut update = get_update(get_unreliable_data()); + update.context.slot += 1; + let result = + decode_and_verify_update(&get_wormhole_pid(), &get_accumulator_address(), update); + assert!(matches!(result, Err(VerifyUpdateError::InvalidMessagePDA))); + } + + #[test] + fn test_decode_and_verify_update_failed_decode() { + let mut update = get_update(get_unreliable_data()); + update.value.account.data = + UiAccountData::Binary("invalid_base64".to_string(), UiAccountEncoding::Base64); + let result = + decode_and_verify_update(&get_wormhole_pid(), &get_accumulator_address(), update); + assert!( + matches!(result, Err(VerifyUpdateError::DecodingError(ref msg)) if msg == FAILED_TO_DECODE), + ); + } + + #[test] + fn test_decode_and_verify_update_invalid_unreliable_data() { + let mut update = get_update(get_unreliable_data()); + let message = base64::engine::general_purpose::STANDARD.encode(vec![4, 1, 2, 3, 4]); + update.value.account.data = UiAccountData::Binary(message, UiAccountEncoding::Base64); + let result = + decode_and_verify_update(&get_wormhole_pid(), &get_accumulator_address(), update); + let error_message = format!( + "{}: {}", + INVALID_UNRELIABLE_DATA_FORMAT, + "Magic mismatch. Expected [109, 115, 117] but got [4, 1, 2]" + ); + assert!( + matches!(result, Err(VerifyUpdateError::DecodingError(ref msg)) + if *msg == error_message) + ); + } + + #[test] + fn test_decode_and_verify_update_invalid_emitter_chain() { + let mut unreliable_data = get_unreliable_data(); + unreliable_data.emitter_chain = Chain::Solana.into(); + let result = decode_and_verify_update( + &get_wormhole_pid(), + &get_accumulator_address(), + get_update(unreliable_data), + ); + assert!(matches!( + result, + Err(VerifyUpdateError::InvalidEmitterChain) + )); + } + + #[test] + fn test_decode_and_verify_update_invalid_emitter_address() { + let mut unreliable_data = get_unreliable_data(); + unreliable_data.emitter_address = Pubkey::new_unique().to_bytes(); + let result = decode_and_verify_update( + &get_wormhole_pid(), + &get_accumulator_address(), + get_update(unreliable_data), + ); + assert!(matches!( + result, + Err(VerifyUpdateError::InvalidAccumulatorAddress) + )); + } +} From 32fb9541270c34477f0054106765c3a7a19c0834 Mon Sep 17 00:00:00 2001 From: Danial Mehrjerdi Date: Fri, 13 Jun 2025 15:58:50 +0200 Subject: [PATCH 2/2] Addressed comments --- src/main.rs | 87 +++++++++++++++++++------------------------ src/posted_message.rs | 9 +++++ 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/src/main.rs b/src/main.rs index be4af0c..95e9793 100644 --- a/src/main.rs +++ b/src/main.rs @@ -43,49 +43,55 @@ fn find_message_pda(wormhole_pid: &Pubkey, slot: u64) -> Pubkey { const FAILED_TO_DECODE: &str = "Failed to decode account data"; const INVALID_UNRELIABLE_DATA_FORMAT: &str = "Invalid unreliable data format"; - -#[derive(Debug)] -enum VerifyUpdateError { - InvalidMessagePDA, - InvalidEmitterChain, - InvalidAccumulatorAddress, - #[allow(dead_code)] - DecodingError(String), -} +const INVALID_PDA_MESSAGE: &str = "Invalid PDA message"; +const INVALID_EMITTER_CHAIN: &str = "Invalid emitter chain"; +const INVALID_ACCUMULATOR_ADDRESS: &str = "Invalid accumulator address"; fn decode_and_verify_update( wormhole_pid: &Pubkey, accumulator_address: &Pubkey, update: Response, -) -> Result { +) -> anyhow::Result { if find_message_pda(wormhole_pid, update.context.slot).to_string() != update.value.pubkey { - return Err(VerifyUpdateError::InvalidMessagePDA); + return Err(anyhow::anyhow!(INVALID_PDA_MESSAGE)); } - let data = update - .value - .account - .data - .decode() - .ok_or(VerifyUpdateError::DecodingError( - FAILED_TO_DECODE.to_string(), - ))?; + let data = update.value.account.data.decode().ok_or_else(|| { + tracing::error!( + data = ?update.value.account.data, + "Failed to decode account data", + ); + anyhow::anyhow!(FAILED_TO_DECODE) + })?; let unreliable_data: PostedMessageUnreliableData = BorshDeserialize::deserialize(&mut data.as_slice()).map_err(|e| { - VerifyUpdateError::DecodingError(format!("{}: {}", INVALID_UNRELIABLE_DATA_FORMAT, e)) + tracing::error!( + data = ?data, + error = ?e, + "Failed to decode unreliable data", + ); + anyhow::anyhow!(format!("{}: {}", INVALID_UNRELIABLE_DATA_FORMAT, e)) })?; if Chain::Pythnet != unreliable_data.emitter_chain.into() { - return Err(VerifyUpdateError::InvalidEmitterChain); + tracing::error!( + emitter_chain = unreliable_data.emitter_chain, + "Invalid emitter chain" + ); + return Err(anyhow::anyhow!(INVALID_EMITTER_CHAIN)); } if accumulator_address != &Pubkey::from(unreliable_data.emitter_address) { - return Err(VerifyUpdateError::InvalidAccumulatorAddress); + tracing::error!( + emitter_address = ?unreliable_data.emitter_address, + "Invalid accumulator address" + ); + return Err(anyhow::anyhow!(INVALID_ACCUMULATOR_ADDRESS)); } Ok(unreliable_data) } -fn new_body(unreliable_data: &PostedMessageUnreliableData) -> Body<&RawMessage> { +fn message_data_to_body(unreliable_data: &PostedMessageUnreliableData) -> Body<&RawMessage> { Body { timestamp: unreliable_data.submission_time, nonce: unreliable_data.nonce, @@ -124,18 +130,13 @@ async fn run_listener(input: RunListenerInput) -> Result<(), PubsubClientError> match decode_and_verify_update(&input.wormhole_pid, &input.accumulator_address, update) { Ok(data) => data, - Err(e) => { - if !matches!(e, VerifyUpdateError::InvalidMessagePDA) { - tracing::error!(error = ?e, "Received an invalid update"); - } - continue; - } + Err(_) => continue, }; tokio::spawn({ let api_client = input.api_client.clone(); async move { - let body = new_body(&unreliable_data); + let body = message_data_to_body(&unreliable_data); match Observation::try_new(body.clone(), input.secret_key) { Ok(observation) => { if let Err(e) = api_client.post_observation(observation).await { @@ -204,11 +205,12 @@ async fn main() { #[cfg(test)] mod tests { + use super::*; + use base64::Engine; use borsh::BorshSerialize; use solana_account_decoder::{UiAccount, UiAccountData}; - use super::*; use crate::posted_message::MessageData; fn get_wormhole_pid() -> Pubkey { @@ -279,7 +281,7 @@ mod tests { #[test] fn test_get_body() { let unreliable_data = get_unreliable_data(); - let body = new_body(&unreliable_data); + let body = message_data_to_body(&unreliable_data); assert_eq!(body.timestamp, unreliable_data.submission_time); assert_eq!(body.nonce, unreliable_data.nonce); assert_eq!(body.emitter_chain, Chain::Pythnet); @@ -337,7 +339,7 @@ mod tests { update.context.slot += 1; let result = decode_and_verify_update(&get_wormhole_pid(), &get_accumulator_address(), update); - assert!(matches!(result, Err(VerifyUpdateError::InvalidMessagePDA))); + assert_eq!(result.unwrap_err().to_string(), INVALID_PDA_MESSAGE); } #[test] @@ -347,9 +349,7 @@ mod tests { UiAccountData::Binary("invalid_base64".to_string(), UiAccountEncoding::Base64); let result = decode_and_verify_update(&get_wormhole_pid(), &get_accumulator_address(), update); - assert!( - matches!(result, Err(VerifyUpdateError::DecodingError(ref msg)) if msg == FAILED_TO_DECODE), - ); + assert_eq!(result.unwrap_err().to_string(), FAILED_TO_DECODE); } #[test] @@ -364,10 +364,7 @@ mod tests { INVALID_UNRELIABLE_DATA_FORMAT, "Magic mismatch. Expected [109, 115, 117] but got [4, 1, 2]" ); - assert!( - matches!(result, Err(VerifyUpdateError::DecodingError(ref msg)) - if *msg == error_message) - ); + assert_eq!(result.unwrap_err().to_string(), error_message); } #[test] @@ -379,10 +376,7 @@ mod tests { &get_accumulator_address(), get_update(unreliable_data), ); - assert!(matches!( - result, - Err(VerifyUpdateError::InvalidEmitterChain) - )); + assert_eq!(result.unwrap_err().to_string(), INVALID_EMITTER_CHAIN); } #[test] @@ -394,9 +388,6 @@ mod tests { &get_accumulator_address(), get_update(unreliable_data), ); - assert!(matches!( - result, - Err(VerifyUpdateError::InvalidAccumulatorAddress) - )); + assert_eq!(result.unwrap_err().to_string(), INVALID_ACCUMULATOR_ADDRESS); } } diff --git a/src/posted_message.rs b/src/posted_message.rs index f7e1cec..80777f5 100644 --- a/src/posted_message.rs +++ b/src/posted_message.rs @@ -1,3 +1,12 @@ +//! This module defines the `PostedMessage` structure used to parse and verify messages +//! posted by the Wormhole protocol. +//! +//! ⚠️ Note: This is mostly a copy-paste from the Wormhole reference implementation. +//! If you forget how it works or need updates, refer to the official source: +//! https://github.com/wormhole-foundation/wormhole/blob/main/solana/bridge/program/src/accounts/posted_message.rs# +//! +//! Keep in sync if the upstream changes! + use { borsh::{BorshDeserialize, BorshSerialize}, serde::{Deserialize, Serialize},