|
| 1 | +use crate::events::{Event, EventHandler}; |
| 2 | +use crate::transport::msgs::{LSPSMessage, Prefix, RawLSPSMessage, LSPS_MESSAGE_TYPE}; |
| 3 | +use bitcoin::secp256k1::PublicKey; |
| 4 | +use lightning::ln::peer_handler::CustomMessageHandler; |
| 5 | +use lightning::ln::wire::CustomMessageReader; |
| 6 | +use std::collections::HashMap; |
| 7 | +use std::convert::{TryFrom, TryInto}; |
| 8 | +use std::io; |
| 9 | +use std::sync::{Arc, Mutex}; |
| 10 | + |
| 11 | +pub trait ProtocolMessageHandler { |
| 12 | + type ProtocolMessage: TryFrom<LSPSMessage> + Into<LSPSMessage>; |
| 13 | + |
| 14 | + fn handle_message( |
| 15 | + &self, message: Self::ProtocolMessage, counterparty_node_id: &PublicKey, |
| 16 | + ) -> Result<(), lightning::ln::msgs::LightningError>; |
| 17 | + fn get_and_clear_pending_protocol_messages(&self) -> Vec<(PublicKey, Self::ProtocolMessage)>; |
| 18 | + fn get_and_clear_pending_protocol_events(&self) -> Vec<Event>; |
| 19 | + fn get_protocol_number(&self) -> Option<u16>; |
| 20 | +} |
| 21 | + |
| 22 | +pub trait MessageHandler { |
| 23 | + fn handle_lsps_message( |
| 24 | + &self, message: LSPSMessage, counterparty_node_id: &PublicKey, |
| 25 | + ) -> Result<(), lightning::ln::msgs::LightningError>; |
| 26 | + fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, LSPSMessage)>; |
| 27 | + fn get_and_clear_pending_events(&self) -> Vec<Event>; |
| 28 | + fn get_protocol_number(&self) -> Option<u16>; |
| 29 | +} |
| 30 | + |
| 31 | +impl<T> MessageHandler for T |
| 32 | +where |
| 33 | + T: ProtocolMessageHandler, |
| 34 | + LSPSMessage: TryInto<<T as ProtocolMessageHandler>::ProtocolMessage>, |
| 35 | +{ |
| 36 | + fn handle_lsps_message( |
| 37 | + &self, message: LSPSMessage, counterparty_node_id: &PublicKey, |
| 38 | + ) -> Result<(), lightning::ln::msgs::LightningError> { |
| 39 | + if let Ok(protocol_message) = message.try_into() { |
| 40 | + self.handle_message(protocol_message, counterparty_node_id)?; |
| 41 | + } |
| 42 | + |
| 43 | + Ok(()) |
| 44 | + } |
| 45 | + |
| 46 | + fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, LSPSMessage)> { |
| 47 | + self.get_and_clear_pending_protocol_messages() |
| 48 | + .into_iter() |
| 49 | + .map(|(public_key, protocol_message)| (public_key, protocol_message.into())) |
| 50 | + .collect() |
| 51 | + } |
| 52 | + |
| 53 | + fn get_and_clear_pending_events(&self) -> Vec<Event> { |
| 54 | + self.get_and_clear_pending_protocol_events() |
| 55 | + } |
| 56 | + |
| 57 | + fn get_protocol_number(&self) -> Option<u16> { |
| 58 | + self.get_protocol_number() |
| 59 | + } |
| 60 | +} |
| 61 | + |
| 62 | +pub struct LSPManager { |
| 63 | + pending_messages: Mutex<Vec<(PublicKey, RawLSPSMessage)>>, |
| 64 | + request_id_to_method_map: Mutex<HashMap<String, String>>, |
| 65 | + message_handlers: Arc<Mutex<HashMap<Prefix, Arc<dyn MessageHandler>>>>, |
| 66 | +} |
| 67 | + |
| 68 | +impl LSPManager { |
| 69 | + pub fn new() -> Self { |
| 70 | + Self { |
| 71 | + pending_messages: Mutex::new(Vec::new()), |
| 72 | + request_id_to_method_map: Mutex::new(HashMap::new()), |
| 73 | + message_handlers: Arc::new(Mutex::new(HashMap::new())), |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + pub fn get_message_handlers(&self) -> Arc<Mutex<HashMap<Prefix, Arc<dyn MessageHandler>>>> { |
| 78 | + self.message_handlers.clone() |
| 79 | + } |
| 80 | + |
| 81 | + pub fn register_message_handler( |
| 82 | + &self, prefix: Prefix, message_handler: Arc<dyn MessageHandler>, |
| 83 | + ) { |
| 84 | + self.message_handlers.lock().unwrap().insert(prefix, message_handler); |
| 85 | + } |
| 86 | + |
| 87 | + pub fn process_pending_events<H: EventHandler>(&self, handler: H) { |
| 88 | + let message_handlers = self.message_handlers.lock().unwrap(); |
| 89 | + |
| 90 | + for message_handler in message_handlers.values() { |
| 91 | + let events = message_handler.get_and_clear_pending_events(); |
| 92 | + for event in events { |
| 93 | + handler.handle_event(event); |
| 94 | + } |
| 95 | + } |
| 96 | + } |
| 97 | + |
| 98 | + fn handle_lsps_message( |
| 99 | + &self, msg: LSPSMessage, sender_node_id: &PublicKey, |
| 100 | + ) -> Result<(), lightning::ln::msgs::LightningError> { |
| 101 | + if let Some(prefix) = msg.prefix() { |
| 102 | + let message_handlers = self.message_handlers.lock().unwrap(); |
| 103 | + // TODO: not sure what we are supposed to do when we receive a message we don't have a handler for |
| 104 | + if let Some(message_handler) = message_handlers.get(&prefix) { |
| 105 | + message_handler.handle_lsps_message(msg, sender_node_id)?; |
| 106 | + } |
| 107 | + } |
| 108 | + Ok(()) |
| 109 | + } |
| 110 | + |
| 111 | + fn enqueue_message(&self, node_id: PublicKey, msg: RawLSPSMessage) { |
| 112 | + let mut pending_msgs = self.pending_messages.lock().unwrap(); |
| 113 | + pending_msgs.push((node_id, msg)); |
| 114 | + } |
| 115 | +} |
| 116 | + |
| 117 | +impl CustomMessageReader for LSPManager { |
| 118 | + type CustomMessage = RawLSPSMessage; |
| 119 | + |
| 120 | + fn read<R: io::Read>( |
| 121 | + &self, message_type: u16, buffer: &mut R, |
| 122 | + ) -> Result<Option<Self::CustomMessage>, lightning::ln::msgs::DecodeError> { |
| 123 | + match message_type { |
| 124 | + LSPS_MESSAGE_TYPE => { |
| 125 | + let mut payload = String::new(); |
| 126 | + buffer.read_to_string(&mut payload)?; |
| 127 | + Ok(Some(RawLSPSMessage { payload })) |
| 128 | + } |
| 129 | + _ => Ok(None), |
| 130 | + } |
| 131 | + } |
| 132 | +} |
| 133 | + |
| 134 | +impl CustomMessageHandler for LSPManager { |
| 135 | + fn handle_custom_message( |
| 136 | + &self, msg: Self::CustomMessage, sender_node_id: &PublicKey, |
| 137 | + ) -> Result<(), lightning::ln::msgs::LightningError> { |
| 138 | + let mut request_id_to_method_map = self.request_id_to_method_map.lock().unwrap(); |
| 139 | + |
| 140 | + match LSPSMessage::from_str_with_id_map(&msg.payload, &mut request_id_to_method_map) { |
| 141 | + Ok(msg) => self.handle_lsps_message(msg, sender_node_id), |
| 142 | + Err(_) => { |
| 143 | + self.enqueue_message( |
| 144 | + *sender_node_id, |
| 145 | + RawLSPSMessage { |
| 146 | + payload: serde_json::to_string(&LSPSMessage::Invalid).unwrap(), |
| 147 | + }, |
| 148 | + ); |
| 149 | + Ok(()) |
| 150 | + } |
| 151 | + } |
| 152 | + } |
| 153 | + |
| 154 | + fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { |
| 155 | + let mut msgs = vec![]; |
| 156 | + |
| 157 | + { |
| 158 | + let mut pending_messages = self.pending_messages.lock().unwrap(); |
| 159 | + msgs.extend( |
| 160 | + pending_messages.drain(..).collect::<Vec<(PublicKey, Self::CustomMessage)>>(), |
| 161 | + ); |
| 162 | + } |
| 163 | + |
| 164 | + let message_handlers = self.message_handlers.lock().unwrap(); |
| 165 | + for message_handler in message_handlers.values() { |
| 166 | + let protocol_messages = message_handler.get_and_clear_pending_msg(); |
| 167 | + msgs.extend(protocol_messages.into_iter().map(|(node_id, message)| { |
| 168 | + (node_id, RawLSPSMessage { payload: serde_json::to_string(&message).unwrap() }) |
| 169 | + })); |
| 170 | + } |
| 171 | + |
| 172 | + msgs |
| 173 | + } |
| 174 | +} |
0 commit comments