// breez_sdk_core/lsps0/transport.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::node_api::NodeAPI;
6use crate::CustomMessage;
7use anyhow::{anyhow, Result};
8use rand::distributions::Alphanumeric;
9use rand::distributions::DistString;
10use serde::de::DeserializeOwned;
11use tokio::sync::watch;
12use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
13use tokio::time::sleep;
14use tokio_stream::StreamExt;
15
16use super::error::Error;
17use super::jsonrpc::RpcServerMessageBody;
18use super::jsonrpc::{RpcError, RpcRequest, RpcServerMessage};
19
/// JSON-RPC version written into every outgoing request and required on every
/// incoming LSPS0 message; messages carrying any other version are dropped.
const JSONRPC_VERSION: &str = "2.0";
/// Custom message type that carries LSPS0 payloads. Outgoing requests use it,
/// and incoming custom messages of any other type are ignored.
const LSPS0_MESSAGE_TYPE: u16 = 37_913;
22
23#[tonic::async_trait]
24trait NotificationSender: Send + Sync {
25    async fn send(&self, params: serde_json::Value) -> Result<(), Error>;
26}
27struct NotificationHandler<TNotification>
28where
29    TNotification: DeserializeOwned + Send,
30{
31    tx: mpsc::Sender<TNotification>,
32}
33
34#[tonic::async_trait]
35impl<TNotification> NotificationSender for NotificationHandler<TNotification>
36where
37    TNotification: DeserializeOwned + Send,
38{
39    async fn send(&self, params: serde_json::Value) -> Result<(), Error> {
40        let n = match serde_json::from_value::<TNotification>(params) {
41            Ok(n) => n,
42            Err(e) => return Err(Error::Deserialization(e)),
43        };
44
45        match self.tx.send(n).await {
46            Ok(_) => Ok(()),
47            Err(_) => Err(Error::Local(anyhow!("receiver dropped"))),
48        }
49    }
50}
51
52struct ResponseOrError {
53    response: Option<serde_json::Value>,
54    error: Option<RpcError>,
55}
56
57/// Transport sends and receives LSPS0 messages to/from remote nodes. One
58/// user node has one Transport.
59pub struct Transport {
60    node: Arc<dyn NodeAPI>,
61    reconnect_interval: Duration,
62    response_handlers: Mutex<HashMap<String, oneshot::Sender<ResponseOrError>>>,
63    notification_handlers: RwLock<HashMap<String, Box<dyn NotificationSender>>>,
64}
65
66impl Transport {
67    #[allow(dead_code)]
68    pub fn new(node: Arc<dyn NodeAPI>) -> Transport {
69        Transport {
70            node,
71            response_handlers: Mutex::new(HashMap::new()),
72            notification_handlers: RwLock::new(HashMap::new()),
73            reconnect_interval: Duration::from_secs(1),
74        }
75    }
76
77    #[allow(dead_code)]
78    pub fn start(self: &Arc<Transport>, cancel: watch::Receiver<()>) {
79        debug!("starting lsps0 transport.");
80        let cloned = self.clone();
81        tokio::spawn(async move {
82            loop {
83                let mut cancel = cancel.clone();
84                if cancel.has_changed().unwrap_or(true) {
85                    return;
86                }
87
88                debug!("lsps0 transport connecting to custom message stream.");
89                let mut stream = match cloned.node.stream_custom_messages().await {
90                    Ok(s) => s,
91                    Err(err) => {
92                        warn!(
93                            "lsps0 transport failed to connect to custom message stream: {}. Retrying in {:?}", 
94                            err,
95                            cloned.reconnect_interval);
96                        break;
97                    }
98                };
99                loop {
100                    tokio::select! {
101                        _ = cancel.changed() => {
102                            debug!("lsps0 tranport cancelled.");
103                            return;
104                        }
105                        msg = stream.next() => {
106                            let msg = match msg {
107                                Some(msg) => match msg {
108                                    Ok(msg) => msg,
109                                    Err(e) => {
110                                        warn!("connection to custom message stream errored: {}", e);
111                                        break;
112                                    }
113                                },
114                                None => {
115                                    warn!("connection to custom message stream dropped");
116                                    break
117                                }
118                            };
119
120                            cloned.handle_message(msg).await;
121                        }
122                    }
123                }
124
125                sleep(cloned.reconnect_interval).await;
126            }
127        });
128    }
129
130    async fn handle_message(&self, msg: CustomMessage) {
131        if msg.message_type != LSPS0_MESSAGE_TYPE {
132            debug!(
133                "received custom message that was not lsps0: node_id={}, type={}, payload={}",
134                hex::encode(&msg.peer_id),
135                msg.message_type,
136                hex::encode(&msg.payload)
137            );
138            return;
139        }
140
141        let v: RpcServerMessage = match serde_json::from_slice(&msg.payload) {
142            Ok(v) => v,
143            Err(e) => {
144                warn!(
145                    "error deserializing lsps0 payload {:?}: {}",
146                    &msg.payload, e
147                );
148                return;
149            }
150        };
151
152        if v.jsonrpc != JSONRPC_VERSION {
153            warn!(
154                "error deserializing lsps0 payload {:?}: Invalid jsonrpc version. Expected {:?}.",
155                &msg.payload, JSONRPC_VERSION
156            );
157            return;
158        }
159
160        match v.body {
161            RpcServerMessageBody::Notification { method, params } => {
162                let id = get_notification_handler_id(&method, msg.peer_id.clone());
163                if let Some(tx) = (*self.notification_handlers.read().await).get(&id) {
164                    match tx.send(params).await {
165                        Ok(_) => (),
166                        Err(e) => match e {
167                            Error::Deserialization(e) => {
168                                // TODO: Drop connection to LSP?
169                                warn!(
170                                    "LSPS0: Got invalid notification {:?} for id {}: {}",
171                                    msg, id, e
172                                );
173                            }
174                            _ => {
175                                debug!("LSPS0: Notification listener dropped for id {}", id);
176                                let mut notification_handlers =
177                                    self.notification_handlers.write().await;
178                                (*notification_handlers).remove(&id);
179                            }
180                        },
181                    }
182                } else {
183                    info!(
184                        "LSPS0: got notification without listener: method: {}, params: {:?}",
185                        method, params
186                    );
187                }
188            }
189            RpcServerMessageBody::Response { id, result } => {
190                let handler_id = get_request_handler_id(&id, msg.peer_id);
191                if let Some(tx) = (*self.response_handlers.lock().await).remove(&handler_id) {
192                    if tx
193                        .send(ResponseOrError {
194                            response: Some(result),
195                            error: None,
196                        })
197                        .is_err()
198                    {
199                        debug!("LSPS0: got response, but listener dropped");
200                    }
201                } else {
202                    debug!(
203                        "LSPS0: got response without listener: id: {}, result: {:?}",
204                        id, result
205                    );
206                }
207            }
208            RpcServerMessageBody::Error { id, error } => {
209                let handler_id = get_request_handler_id(&id, msg.peer_id);
210                if let Some(tx) = (*self.response_handlers.lock().await).remove(&handler_id) {
211                    if tx
212                        .send(ResponseOrError {
213                            response: None,
214                            error: Some(error),
215                        })
216                        .is_err()
217                    {
218                        debug!("LSPS0: got error response, but listener dropped");
219                    }
220                } else {
221                    debug!(
222                        "LSPS0: got error without listener: id: {}, error: {:?}",
223                        id, error
224                    );
225                }
226            }
227        }
228    }
229
230    pub async fn request_response<TRequest, TResponse>(
231        &self,
232        method: String,
233        peer_id: Vec<u8>,
234        req: &TRequest,
235        timeout: Duration,
236    ) -> Result<TResponse, Error>
237    where
238        TRequest: serde::Serialize,
239        TResponse: serde::de::DeserializeOwned,
240    {
241        let request_id = generate_request_id();
242        let wrapped_req = RpcRequest {
243            id: request_id.clone(),
244            jsonrpc: String::from(JSONRPC_VERSION),
245            method,
246            params: req,
247        };
248        let payload = serde_json::to_vec(&wrapped_req)?;
249        let (tx, rx) = oneshot::channel();
250        let handler_id = get_request_handler_id(&request_id, peer_id.clone());
251        (*self.response_handlers.lock().await).insert(handler_id.clone(), tx);
252
253        if let Err(e) = self
254            .node
255            .send_custom_message(CustomMessage {
256                peer_id,
257                message_type: LSPS0_MESSAGE_TYPE,
258                payload,
259            })
260            .await
261        {
262            (*self.response_handlers.lock().await).remove(&handler_id);
263            return Err(e.into());
264        }
265
266        let result_or = tokio::time::timeout(timeout, rx).await?;
267        let response_or_error = result_or?;
268        if let Some(response) = response_or_error.response {
269            let resp = serde_json::from_value::<TResponse>(response)?;
270            Ok(resp)
271        } else if let Some(error) = response_or_error.error {
272            Err(Error::Remote(error))
273        } else {
274            Err(Error::Local(anyhow!("did not get response or error")))
275        }
276    }
277
278    pub async fn stream_notifications<TNotification>(
279        &self,
280        method: String,
281        node_id: Vec<u8>,
282    ) -> Result<mpsc::Receiver<TNotification>, Error>
283    where
284        TNotification: serde::de::DeserializeOwned + std::marker::Send + 'static,
285    {
286        let (tx, rx) = mpsc::channel(100);
287        let id = get_notification_handler_id(&method, node_id);
288        (*self.notification_handlers.write().await)
289            .insert(id, Box::new(NotificationHandler::<TNotification> { tx }));
290
291        Ok(rx)
292    }
293}
294
295fn generate_request_id() -> String {
296    Alphanumeric.sample_string(&mut rand::thread_rng(), 21)
297}
298
299fn get_request_handler_id(request_id: &str, node_id: Vec<u8>) -> String {
300    let mut id = hex::encode(node_id);
301    id.push('|');
302    id.push_str(request_id);
303    id
304}
305
306fn get_notification_handler_id(method: &str, node_id: Vec<u8>) -> String {
307    let mut id = hex::encode(node_id);
308    id.push('|');
309    id.push_str(method);
310    id
311}
312
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::{sync::Arc, time::Duration};
    use tokio::sync::{mpsc, watch};

    use crate::{
        breez_services::tests::get_dummy_node_state,
        lsps0::{
            error::Error,
            jsonrpc::{RpcError, RpcRequest, RpcServerMessage, RpcServerMessageBody},
            transport::LSPS0_MESSAGE_TYPE,
        },
        test_utils::MockNodeAPI,
        CustomMessage,
    };

    use super::Transport;

    #[derive(Serialize, Deserialize)]
    pub struct Request {}

    #[derive(Serialize, Deserialize)]
    pub struct Response {}

    #[derive(Serialize, Deserialize)]
    pub struct Notification {}

    /// Reply with an empty result for the request's id, echoing its `jsonrpc`.
    fn ok_reply(req: RpcRequest<Request>) -> RpcServerMessage {
        RpcServerMessage {
            jsonrpc: req.jsonrpc,
            body: RpcServerMessageBody::Response {
                id: req.id,
                result: json!({}),
            },
        }
    }

    /// Returns a mock node that answers every custom message it is asked to send.
    ///
    /// The outgoing payload is parsed as an LSPS0 request and `make_reply` turns
    /// it into a server message. That message is pushed onto the node's custom
    /// message stream as coming from `reply_peer_id`, with type
    /// `reply_message_type`.
    async fn responding_node(
        reply_peer_id: Vec<u8>,
        reply_message_type: u16,
        make_reply: fn(RpcRequest<Request>) -> RpcServerMessage,
    ) -> MockNodeAPI {
        let (tx, rx) = mpsc::channel(1);
        let tx = Arc::new(tx);
        let on_send_request = move |message: CustomMessage| {
            let req = serde_json::from_slice::<RpcRequest<Request>>(&message.payload).unwrap();
            let raw_reply = serde_json::to_vec(&make_reply(req)).unwrap();
            let tx = tx.clone();
            let peer_id = reply_peer_id.clone();
            // Push from a separate task so this callback never waits on the
            // (capacity 1) stream channel.
            tokio::spawn(async move {
                tx.send(CustomMessage {
                    message_type: reply_message_type,
                    payload: raw_reply,
                    peer_id,
                })
                .await
                .unwrap();
            });
            Ok(())
        };

        let mut node_api = MockNodeAPI::new(get_dummy_node_state());
        node_api.set_on_send_custom_message(Box::new(on_send_request));
        node_api.set_on_stream_custom_messages(rx).await;
        node_api
    }

    /// Starts a transport on `node_api` and sends one `test` request to
    /// `peer_id` with a 10ms timeout. It then stops the transport and returns
    /// the request's outcome.
    async fn send_test_request(
        node_api: MockNodeAPI,
        peer_id: Vec<u8>,
    ) -> Result<Response, Error> {
        let transport = Arc::new(Transport::new(Arc::new(node_api)));
        let (stop, cancel) = watch::channel(());
        transport.start(cancel);
        let result = transport
            .request_response::<Request, Response>(
                String::from("test"),
                peer_id,
                &Request {},
                Duration::from_millis(10),
            )
            .await;
        let _ = stop.send(());
        result
    }

    /// Encodes a `test` notification from `peer_id` as an LSPS0 custom message.
    fn test_notification_from(peer_id: Vec<u8>) -> CustomMessage {
        let notification = RpcServerMessage {
            jsonrpc: String::from("2.0"),
            body: RpcServerMessageBody::Notification {
                method: String::from("test"),
                params: json!({}),
            },
        };
        CustomMessage {
            message_type: LSPS0_MESSAGE_TYPE,
            payload: serde_json::to_vec(&notification).unwrap(),
            peer_id,
        }
    }

    #[tokio::test]
    async fn test_request_response_success() {
        let node_api = responding_node(vec![21], LSPS0_MESSAGE_TYPE, ok_reply).await;
        send_test_request(node_api, vec![21]).await.unwrap();
    }

    #[tokio::test]
    async fn test_request_response_timeout() {
        let node_api = MockNodeAPI::new(get_dummy_node_state());
        let result = send_test_request(node_api, vec![21]).await;
        assert!(matches!(result.err().unwrap(), Error::Timeout));
    }

    #[tokio::test]
    async fn test_response_from_different_node() {
        let node_api = responding_node(vec![22], LSPS0_MESSAGE_TYPE, ok_reply).await;
        let result = send_test_request(node_api, vec![21]).await;
        assert!(matches!(result.err().unwrap(), Error::Timeout));
    }

    #[tokio::test]
    async fn test_request_response_remote_error() {
        let node_api = responding_node(
            vec![21],
            LSPS0_MESSAGE_TYPE,
            |req: RpcRequest<Request>| RpcServerMessage {
                jsonrpc: req.jsonrpc,
                body: RpcServerMessageBody::Error {
                    id: req.id,
                    error: RpcError {
                        code: 1,
                        data: Some(json!({})),
                        message: String::from("error occurred"),
                    },
                },
            },
        )
        .await;
        let result = send_test_request(node_api, vec![21]).await;

        match result.err().unwrap() {
            Error::Remote(e) => {
                assert_eq!(e.code, 1);
                assert_eq!(e.message, String::from("error occurred"));
                assert_eq!(e.data, Some(json!({})));
            }
            other => panic!("expected Error::Remote, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_request_response_deserialization_error() {
        let node_api = responding_node(
            vec![21],
            LSPS0_MESSAGE_TYPE,
            |req: RpcRequest<Request>| RpcServerMessage {
                jsonrpc: req.jsonrpc,
                body: RpcServerMessageBody::Response {
                    id: req.id,
                    result: json!("cannot deserialize this"),
                },
            },
        )
        .await;
        let result = send_test_request(node_api, vec![21]).await;
        assert!(matches!(result.err().unwrap(), Error::Deserialization { .. }));
    }

    #[tokio::test]
    async fn test_request_response_different_id() {
        let node_api = responding_node(
            vec![21],
            LSPS0_MESSAGE_TYPE,
            |req: RpcRequest<Request>| RpcServerMessage {
                jsonrpc: req.jsonrpc,
                body: RpcServerMessageBody::Response {
                    id: String::from("different id"),
                    result: json!({}),
                },
            },
        )
        .await;
        let result = send_test_request(node_api, vec![21]).await;
        assert!(matches!(result.err().unwrap(), Error::Timeout));
    }

    #[tokio::test]
    async fn test_request_response_not_lsps0() {
        let node_api = responding_node(vec![21], LSPS0_MESSAGE_TYPE + 1, ok_reply).await;
        let result = send_test_request(node_api, vec![21]).await;
        assert!(matches!(result.err().unwrap(), Error::Timeout));
    }

    #[tokio::test]
    async fn test_request_response_wrong_jsonrpc_version() {
        let node_api = responding_node(
            vec![21],
            LSPS0_MESSAGE_TYPE,
            |req: RpcRequest<Request>| RpcServerMessage {
                jsonrpc: String::from("1.0"),
                body: RpcServerMessageBody::Response {
                    id: req.id,
                    result: json!({}),
                },
            },
        )
        .await;
        let result = send_test_request(node_api, vec![21]).await;
        assert!(matches!(result.err().unwrap(), Error::Timeout));
    }

    #[tokio::test]
    async fn test_notification_success() {
        let (tx, rx) = mpsc::channel(1);
        let mut node_api = MockNodeAPI::new(get_dummy_node_state());
        node_api.set_on_stream_custom_messages(rx).await;

        let transport = Arc::new(Transport::new(Arc::new(node_api)));
        let (stop, cancel) = watch::channel(());
        transport.start(cancel);
        let mut stream = transport
            .stream_notifications::<Notification>(String::from("test"), vec![21])
            .await
            .unwrap();
        tx.send(test_notification_from(vec![21])).await.unwrap();
        stream.recv().await.unwrap();
        let _ = stop.send(());
    }

    #[tokio::test]
    async fn test_notification_different_node() {
        let (tx, rx) = mpsc::channel(1);
        let mut node_api = MockNodeAPI::new(get_dummy_node_state());
        node_api.set_on_stream_custom_messages(rx).await;

        let transport = Arc::new(Transport::new(Arc::new(node_api)));
        let (stop, cancel) = watch::channel(());
        transport.start(cancel);
        let mut stream = transport
            .stream_notifications::<Notification>(String::from("test"), vec![21])
            .await
            .unwrap();
        tx.send(test_notification_from(vec![22])).await.unwrap();
        // Give the transport time to process the message before asserting that
        // nothing was delivered; an immediate `try_recv` would pass vacuously.
        let received = tokio::time::timeout(Duration::from_millis(10), stream.recv()).await;
        assert!(received.is_err());
        let _ = stop.send(());
    }
}