1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::node_api::NodeAPI;
6use crate::CustomMessage;
7use anyhow::{anyhow, Result};
8use rand::distributions::Alphanumeric;
9use rand::distributions::DistString;
10use serde::de::DeserializeOwned;
11use tokio::sync::watch;
12use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
13use tokio::time::sleep;
14use tokio_stream::StreamExt;
15
16use super::error::Error;
17use super::jsonrpc::RpcServerMessageBody;
18use super::jsonrpc::{RpcError, RpcRequest, RpcServerMessage};
19
/// Custom message type that marks a peer message as LSPS0 traffic. Incoming
/// custom messages carrying any other type are ignored by the transport.
const LSPS0_MESSAGE_TYPE: u16 = 37913;
/// JSON-RPC version string put on outgoing requests and required on every
/// incoming message.
const JSONRPC_VERSION: &str = "2.0";
22
/// Type-erased sink for the raw `params` of an incoming notification.
///
/// Erasing the concrete notification type lets handlers for different
/// notification types be stored side by side as `Box<dyn NotificationSender>`.
#[tonic::async_trait]
trait NotificationSender: Send + Sync {
    /// Hands the raw JSON `params` of one notification to the listener.
    async fn send(&self, params: serde_json::Value) -> Result<(), Error>;
}
27struct NotificationHandler<TNotification>
28where
29 TNotification: DeserializeOwned + Send,
30{
31 tx: mpsc::Sender<TNotification>,
32}
33
34#[tonic::async_trait]
35impl<TNotification> NotificationSender for NotificationHandler<TNotification>
36where
37 TNotification: DeserializeOwned + Send,
38{
39 async fn send(&self, params: serde_json::Value) -> Result<(), Error> {
40 let n = match serde_json::from_value::<TNotification>(params) {
41 Ok(n) => n,
42 Err(e) => return Err(Error::Deserialization(e)),
43 };
44
45 match self.tx.send(n).await {
46 Ok(_) => Ok(()),
47 Err(_) => Err(Error::Local(anyhow!("receiver dropped"))),
48 }
49 }
50}
51
/// Outcome of a request as delivered to the waiting caller: either the
/// `result` of a successful response or the `error` of a failed one.
struct ResponseOrError {
    /// Raw JSON result; set when the peer answered with a response.
    response: Option<serde_json::Value>,
    /// Remote error; set when the peer answered with an error.
    error: Option<RpcError>,
}
56
/// JSON-RPC transport for LSPS0, carried over the node's custom peer messages.
pub struct Transport {
    /// Node used to send custom messages and to stream incoming ones.
    node: Arc<dyn NodeAPI>,
    /// Pause before reconnecting to the custom message stream.
    reconnect_interval: Duration,
    /// Callers waiting for a response, keyed by `<hex node id>|<request id>`.
    response_handlers: Mutex<HashMap<String, oneshot::Sender<ResponseOrError>>>,
    /// Notification listeners, keyed by `<hex node id>|<method>`.
    notification_handlers: RwLock<HashMap<String, Box<dyn NotificationSender>>>,
}
65
66impl Transport {
67 #[allow(dead_code)]
68 pub fn new(node: Arc<dyn NodeAPI>) -> Transport {
69 Transport {
70 node,
71 response_handlers: Mutex::new(HashMap::new()),
72 notification_handlers: RwLock::new(HashMap::new()),
73 reconnect_interval: Duration::from_secs(1),
74 }
75 }
76
77 #[allow(dead_code)]
78 pub fn start(self: &Arc<Transport>, cancel: watch::Receiver<()>) {
79 debug!("starting lsps0 transport.");
80 let cloned = self.clone();
81 tokio::spawn(async move {
82 loop {
83 let mut cancel = cancel.clone();
84 if cancel.has_changed().unwrap_or(true) {
85 return;
86 }
87
88 debug!("lsps0 transport connecting to custom message stream.");
89 let mut stream = match cloned.node.stream_custom_messages().await {
90 Ok(s) => s,
91 Err(err) => {
92 warn!(
93 "lsps0 transport failed to connect to custom message stream: {}. Retrying in {:?}",
94 err,
95 cloned.reconnect_interval);
96 break;
97 }
98 };
99 loop {
100 tokio::select! {
101 _ = cancel.changed() => {
102 debug!("lsps0 tranport cancelled.");
103 return;
104 }
105 msg = stream.next() => {
106 let msg = match msg {
107 Some(msg) => match msg {
108 Ok(msg) => msg,
109 Err(e) => {
110 warn!("connection to custom message stream errored: {}", e);
111 break;
112 }
113 },
114 None => {
115 warn!("connection to custom message stream dropped");
116 break
117 }
118 };
119
120 cloned.handle_message(msg).await;
121 }
122 }
123 }
124
125 sleep(cloned.reconnect_interval).await;
126 }
127 });
128 }
129
130 async fn handle_message(&self, msg: CustomMessage) {
131 if msg.message_type != LSPS0_MESSAGE_TYPE {
132 debug!(
133 "received custom message that was not lsps0: node_id={}, type={}, payload={}",
134 hex::encode(&msg.peer_id),
135 msg.message_type,
136 hex::encode(&msg.payload)
137 );
138 return;
139 }
140
141 let v: RpcServerMessage = match serde_json::from_slice(&msg.payload) {
142 Ok(v) => v,
143 Err(e) => {
144 warn!(
145 "error deserializing lsps0 payload {:?}: {}",
146 &msg.payload, e
147 );
148 return;
149 }
150 };
151
152 if v.jsonrpc != JSONRPC_VERSION {
153 warn!(
154 "error deserializing lsps0 payload {:?}: Invalid jsonrpc version. Expected {:?}.",
155 &msg.payload, JSONRPC_VERSION
156 );
157 return;
158 }
159
160 match v.body {
161 RpcServerMessageBody::Notification { method, params } => {
162 let id = get_notification_handler_id(&method, msg.peer_id.clone());
163 if let Some(tx) = (*self.notification_handlers.read().await).get(&id) {
164 match tx.send(params).await {
165 Ok(_) => (),
166 Err(e) => match e {
167 Error::Deserialization(e) => {
168 warn!(
170 "LSPS0: Got invalid notification {:?} for id {}: {}",
171 msg, id, e
172 );
173 }
174 _ => {
175 debug!("LSPS0: Notification listener dropped for id {}", id);
176 let mut notification_handlers =
177 self.notification_handlers.write().await;
178 (*notification_handlers).remove(&id);
179 }
180 },
181 }
182 } else {
183 info!(
184 "LSPS0: got notification without listener: method: {}, params: {:?}",
185 method, params
186 );
187 }
188 }
189 RpcServerMessageBody::Response { id, result } => {
190 let handler_id = get_request_handler_id(&id, msg.peer_id);
191 if let Some(tx) = (*self.response_handlers.lock().await).remove(&handler_id) {
192 if tx
193 .send(ResponseOrError {
194 response: Some(result),
195 error: None,
196 })
197 .is_err()
198 {
199 debug!("LSPS0: got response, but listener dropped");
200 }
201 } else {
202 debug!(
203 "LSPS0: got response without listener: id: {}, result: {:?}",
204 id, result
205 );
206 }
207 }
208 RpcServerMessageBody::Error { id, error } => {
209 let handler_id = get_request_handler_id(&id, msg.peer_id);
210 if let Some(tx) = (*self.response_handlers.lock().await).remove(&handler_id) {
211 if tx
212 .send(ResponseOrError {
213 response: None,
214 error: Some(error),
215 })
216 .is_err()
217 {
218 debug!("LSPS0: got error response, but listener dropped");
219 }
220 } else {
221 debug!(
222 "LSPS0: got error without listener: id: {}, error: {:?}",
223 id, error
224 );
225 }
226 }
227 }
228 }
229
230 pub async fn request_response<TRequest, TResponse>(
231 &self,
232 method: String,
233 peer_id: Vec<u8>,
234 req: &TRequest,
235 timeout: Duration,
236 ) -> Result<TResponse, Error>
237 where
238 TRequest: serde::Serialize,
239 TResponse: serde::de::DeserializeOwned,
240 {
241 let request_id = generate_request_id();
242 let wrapped_req = RpcRequest {
243 id: request_id.clone(),
244 jsonrpc: String::from(JSONRPC_VERSION),
245 method,
246 params: req,
247 };
248 let payload = serde_json::to_vec(&wrapped_req)?;
249 let (tx, rx) = oneshot::channel();
250 let handler_id = get_request_handler_id(&request_id, peer_id.clone());
251 (*self.response_handlers.lock().await).insert(handler_id.clone(), tx);
252
253 if let Err(e) = self
254 .node
255 .send_custom_message(CustomMessage {
256 peer_id,
257 message_type: LSPS0_MESSAGE_TYPE,
258 payload,
259 })
260 .await
261 {
262 (*self.response_handlers.lock().await).remove(&handler_id);
263 return Err(e.into());
264 }
265
266 let result_or = tokio::time::timeout(timeout, rx).await?;
267 let response_or_error = result_or?;
268 if let Some(response) = response_or_error.response {
269 let resp = serde_json::from_value::<TResponse>(response)?;
270 Ok(resp)
271 } else if let Some(error) = response_or_error.error {
272 Err(Error::Remote(error))
273 } else {
274 Err(Error::Local(anyhow!("did not get response or error")))
275 }
276 }
277
278 pub async fn stream_notifications<TNotification>(
279 &self,
280 method: String,
281 node_id: Vec<u8>,
282 ) -> Result<mpsc::Receiver<TNotification>, Error>
283 where
284 TNotification: serde::de::DeserializeOwned + std::marker::Send + 'static,
285 {
286 let (tx, rx) = mpsc::channel(100);
287 let id = get_notification_handler_id(&method, node_id);
288 (*self.notification_handlers.write().await)
289 .insert(id, Box::new(NotificationHandler::<TNotification> { tx }));
290
291 Ok(rx)
292 }
293}
294
295fn generate_request_id() -> String {
296 Alphanumeric.sample_string(&mut rand::thread_rng(), 21)
297}
298
299fn get_request_handler_id(request_id: &str, node_id: Vec<u8>) -> String {
300 let mut id = hex::encode(node_id);
301 id.push('|');
302 id.push_str(request_id);
303 id
304}
305
306fn get_notification_handler_id(method: &str, node_id: Vec<u8>) -> String {
307 let mut id = hex::encode(node_id);
308 id.push('|');
309 id.push_str(method);
310 id
311}
312
313#[cfg(test)]
314mod tests {
315 use serde::{Deserialize, Serialize};
316 use serde_json::json;
317 use std::{sync::Arc, time::Duration};
318 use tokio::sync::{mpsc, watch};
319
320 use crate::{
321 breez_services::tests::get_dummy_node_state,
322 lsps0::{
323 error::Error,
324 jsonrpc::{RpcError, RpcRequest, RpcServerMessage, RpcServerMessageBody},
325 transport::LSPS0_MESSAGE_TYPE,
326 },
327 test_utils::MockNodeAPI,
328 CustomMessage,
329 };
330
331 use super::Transport;
332
333 #[derive(Serialize, Deserialize)]
334 pub struct Request {}
335
336 #[derive(Serialize, Deserialize)]
337 pub struct Response {}
338
339 #[derive(Serialize, Deserialize)]
340 pub struct Notification {}
341
342 #[tokio::test]
343 async fn test_request_response_success() {
344 let peer_id = vec![21];
345 let peer_id_clone = peer_id.clone();
346 let (tx, rx) = mpsc::channel(1);
347 let tx_arc = Arc::new(tx);
348 let on_send_request = move |message: CustomMessage| {
349 let req = serde_json::from_slice::<RpcRequest<Request>>(&message.payload).unwrap();
350 let resp = RpcServerMessage {
351 jsonrpc: req.jsonrpc,
352 body: RpcServerMessageBody::Response {
353 id: req.id,
354 result: json!({}),
355 },
356 };
357 let raw_resp = serde_json::to_vec(&resp).unwrap();
358 let tx_arc = tx_arc.clone();
359 let peer_id = peer_id_clone.clone();
360 tokio::spawn(async move {
361 tx_arc
362 .send(CustomMessage {
363 message_type: LSPS0_MESSAGE_TYPE,
364 payload: raw_resp,
365 peer_id,
366 })
367 .await
368 .unwrap();
369 });
370 Ok(())
371 };
372
373 let mut node_api = MockNodeAPI::new(get_dummy_node_state());
374 node_api.set_on_send_custom_message(Box::new(on_send_request));
375 node_api.set_on_stream_custom_messages(rx).await;
376
377 let transport = Arc::new(Transport::new(Arc::new(node_api)));
378 let (stop, cancel) = watch::channel(());
379 transport.start(cancel);
380 let timeout = Duration::from_millis(10);
381 transport
382 .request_response::<Request, Response>(
383 String::from("test"),
384 peer_id.clone(),
385 &Request {},
386 timeout,
387 )
388 .await
389 .unwrap();
390 let _ = stop.send(());
391 }
392
393 #[tokio::test]
394 async fn test_request_response_timeout() {
395 let peer_id = vec![21];
396 let node_api = MockNodeAPI::new(get_dummy_node_state());
397 let transport = Arc::new(Transport::new(Arc::new(node_api)));
398 let (stop, cancel) = watch::channel(());
399 transport.start(cancel);
400 let timeout = Duration::from_millis(10);
401 let result = transport
402 .request_response::<Request, Response>(
403 String::from("test"),
404 peer_id.clone(),
405 &Request {},
406 timeout,
407 )
408 .await;
409
410 assert!(matches!(result.err().unwrap(), Error::Timeout));
411 let _ = stop.send(());
412 }
413
414 #[tokio::test]
415 async fn test_response_from_different_node() {
416 let (tx, rx) = mpsc::channel(1);
417 let tx_arc = Arc::new(tx);
418 let on_send_request = move |message: CustomMessage| {
419 let req = serde_json::from_slice::<RpcRequest<Request>>(&message.payload).unwrap();
420 let resp = RpcServerMessage {
421 jsonrpc: req.jsonrpc,
422 body: RpcServerMessageBody::Response {
423 id: req.id,
424 result: json!({}),
425 },
426 };
427 let raw_resp = serde_json::to_vec(&resp).unwrap();
428 let tx_arc = tx_arc.clone();
429 tokio::spawn(async move {
430 tx_arc
431 .send(CustomMessage {
432 message_type: LSPS0_MESSAGE_TYPE,
433 payload: raw_resp,
434 peer_id: vec![22],
435 })
436 .await
437 .unwrap();
438 });
439 Ok(())
440 };
441
442 let mut node_api = MockNodeAPI::new(get_dummy_node_state());
443 node_api.set_on_send_custom_message(Box::new(on_send_request));
444 node_api.set_on_stream_custom_messages(rx).await;
445
446 let transport = Arc::new(Transport::new(Arc::new(node_api)));
447 let (stop, cancel) = watch::channel(());
448 transport.start(cancel);
449 let timeout = Duration::from_millis(10);
450 let result = transport
451 .request_response::<Request, Response>(
452 String::from("test"),
453 vec![21],
454 &Request {},
455 timeout,
456 )
457 .await;
458
459 assert!(matches!(result.err().unwrap(), Error::Timeout));
460 let _ = stop.send(());
461 }
462
463 #[tokio::test]
464 async fn test_request_response_remote_error() {
465 let peer_id = vec![21];
466 let peer_id_clone = peer_id.clone();
467 let (tx, rx) = mpsc::channel(1);
468 let tx_arc = Arc::new(tx);
469 let on_send_request = move |message: CustomMessage| {
470 let req = serde_json::from_slice::<RpcRequest<Request>>(&message.payload).unwrap();
471 let resp = RpcServerMessage {
472 jsonrpc: req.jsonrpc,
473 body: RpcServerMessageBody::Error {
474 id: req.id,
475 error: RpcError {
476 code: 1,
477 data: Some(json!({})),
478 message: String::from("error occurred"),
479 },
480 },
481 };
482 let raw_resp = serde_json::to_vec(&resp).unwrap();
483 let tx_arc = tx_arc.clone();
484 let peer_id = peer_id_clone.clone();
485 tokio::spawn(async move {
486 tx_arc
487 .send(CustomMessage {
488 message_type: LSPS0_MESSAGE_TYPE,
489 payload: raw_resp,
490 peer_id,
491 })
492 .await
493 .unwrap();
494 });
495 Ok(())
496 };
497
498 let mut node_api = MockNodeAPI::new(get_dummy_node_state());
499 node_api.set_on_send_custom_message(Box::new(on_send_request));
500 node_api.set_on_stream_custom_messages(rx).await;
501
502 let transport = Arc::new(Transport::new(Arc::new(node_api)));
503 let (stop, cancel) = watch::channel(());
504 transport.start(cancel);
505 let timeout = Duration::from_millis(10);
506 let result = transport
507 .request_response::<Request, Response>(
508 String::from("test"),
509 peer_id.clone(),
510 &Request {},
511 timeout,
512 )
513 .await;
514
515 let err = result.err().unwrap();
516 assert!(matches!(err, Error::Remote { .. }));
517 match err {
518 Error::Remote(e) => {
519 assert_eq!(e.code, 1);
520 assert_eq!(e.message, String::from("error occurred"));
521 assert_eq!(e.data, Some(json!({})));
522 }
523 _ => unreachable!(),
524 };
525 let _ = stop.send(());
526 }
527
528 #[tokio::test]
529 async fn test_request_response_deserialization_error() {
530 let peer_id = vec![21];
531 let peer_id_clone = peer_id.clone();
532 let (tx, rx) = mpsc::channel(1);
533 let tx_arc = Arc::new(tx);
534 let on_send_request = move |message: CustomMessage| {
535 let req = serde_json::from_slice::<RpcRequest<Request>>(&message.payload).unwrap();
536 let resp = RpcServerMessage {
537 jsonrpc: req.jsonrpc,
538 body: RpcServerMessageBody::Response {
539 id: req.id,
540 result: json!("cannot deserialize this"),
541 },
542 };
543 let raw_resp = serde_json::to_vec(&resp).unwrap();
544 let tx_arc = tx_arc.clone();
545 let peer_id = peer_id_clone.clone();
546 tokio::spawn(async move {
547 tx_arc
548 .send(CustomMessage {
549 message_type: LSPS0_MESSAGE_TYPE,
550 payload: raw_resp,
551 peer_id,
552 })
553 .await
554 .unwrap();
555 });
556 Ok(())
557 };
558
559 let mut node_api = MockNodeAPI::new(get_dummy_node_state());
560 node_api.set_on_send_custom_message(Box::new(on_send_request));
561 node_api.set_on_stream_custom_messages(rx).await;
562
563 let transport = Arc::new(Transport::new(Arc::new(node_api)));
564 let (stop, cancel) = watch::channel(());
565 transport.start(cancel);
566 let timeout = Duration::from_millis(10);
567 let result = transport
568 .request_response::<Request, Response>(
569 String::from("test"),
570 peer_id.clone(),
571 &Request {},
572 timeout,
573 )
574 .await;
575
576 let err = result.err().unwrap();
577 assert!(matches!(err, Error::Deserialization { .. }));
578 let _ = stop.send(());
579 }
580
581 #[tokio::test]
582 async fn test_request_response_different_id() {
583 let peer_id = vec![21];
584 let peer_id_clone = peer_id.clone();
585 let (tx, rx) = mpsc::channel(1);
586 let tx_arc = Arc::new(tx);
587 let on_send_request = move |message: CustomMessage| {
588 let req = serde_json::from_slice::<RpcRequest<Request>>(&message.payload).unwrap();
589 let resp = RpcServerMessage {
590 jsonrpc: req.jsonrpc,
591 body: RpcServerMessageBody::Response {
592 id: String::from("different id"),
593 result: json!({}),
594 },
595 };
596 let raw_resp = serde_json::to_vec(&resp).unwrap();
597 let tx_arc = tx_arc.clone();
598 let peer_id = peer_id_clone.clone();
599 tokio::spawn(async move {
600 tx_arc
601 .send(CustomMessage {
602 message_type: LSPS0_MESSAGE_TYPE,
603 payload: raw_resp,
604 peer_id,
605 })
606 .await
607 .unwrap();
608 });
609 Ok(())
610 };
611
612 let mut node_api = MockNodeAPI::new(get_dummy_node_state());
613 node_api.set_on_send_custom_message(Box::new(on_send_request));
614 node_api.set_on_stream_custom_messages(rx).await;
615
616 let transport = Arc::new(Transport::new(Arc::new(node_api)));
617 let (stop, cancel) = watch::channel(());
618 transport.start(cancel);
619 let timeout = Duration::from_millis(10);
620 let result = transport
621 .request_response::<Request, Response>(
622 String::from("test"),
623 peer_id.clone(),
624 &Request {},
625 timeout,
626 )
627 .await;
628 assert!(matches!(result.err().unwrap(), Error::Timeout));
629 let _ = stop.send(());
630 }
631
632 #[tokio::test]
633 async fn test_request_response_not_lsps0() {
634 let peer_id = vec![21];
635 let peer_id_clone = peer_id.clone();
636 let (tx, rx) = mpsc::channel(1);
637 let tx_arc = Arc::new(tx);
638 let on_send_request = move |message: CustomMessage| {
639 let req = serde_json::from_slice::<RpcRequest<Request>>(&message.payload).unwrap();
640 let resp = RpcServerMessage {
641 jsonrpc: req.jsonrpc,
642 body: RpcServerMessageBody::Response {
643 id: req.id,
644 result: json!({}),
645 },
646 };
647 let raw_resp = serde_json::to_vec(&resp).unwrap();
648 let tx_arc = tx_arc.clone();
649 let peer_id = peer_id_clone.clone();
650 tokio::spawn(async move {
651 tx_arc
652 .send(CustomMessage {
653 message_type: LSPS0_MESSAGE_TYPE + 1,
654 payload: raw_resp,
655 peer_id,
656 })
657 .await
658 .unwrap();
659 });
660 Ok(())
661 };
662
663 let mut node_api = MockNodeAPI::new(get_dummy_node_state());
664 node_api.set_on_send_custom_message(Box::new(on_send_request));
665 node_api.set_on_stream_custom_messages(rx).await;
666
667 let transport = Arc::new(Transport::new(Arc::new(node_api)));
668 let (stop, cancel) = watch::channel(());
669 transport.start(cancel);
670 let timeout = Duration::from_millis(10);
671 let result = transport
672 .request_response::<Request, Response>(
673 String::from("test"),
674 peer_id.clone(),
675 &Request {},
676 timeout,
677 )
678 .await;
679 assert!(matches!(result.err().unwrap(), Error::Timeout));
680 let _ = stop.send(());
681 }
682
683 #[tokio::test]
684 async fn test_notification_success() {
685 let method = String::from("test");
686 let peer_id = vec![21];
687 let (tx, rx) = mpsc::channel(1);
688 let mut node_api = MockNodeAPI::new(get_dummy_node_state());
689 node_api.set_on_stream_custom_messages(rx).await;
690
691 let transport = Arc::new(Transport::new(Arc::new(node_api)));
692 let (stop, cancel) = watch::channel(());
693 transport.start(cancel);
694 let mut stream = transport
695 .stream_notifications::<Notification>(method, peer_id.clone())
696 .await
697 .unwrap();
698 let payload = RpcServerMessage {
699 jsonrpc: String::from("2.0"),
700 body: RpcServerMessageBody::Notification {
701 method: String::from("test"),
702 params: json!({}),
703 },
704 };
705 let raw_payload = serde_json::to_vec(&payload).unwrap();
706 tx.send(CustomMessage {
707 message_type: LSPS0_MESSAGE_TYPE,
708 payload: raw_payload,
709 peer_id: peer_id.clone(),
710 })
711 .await
712 .unwrap();
713 stream.recv().await.unwrap();
714 let _ = stop.send(());
715 }
716
717 #[tokio::test]
718 async fn test_notification_different_node() {
719 let method = String::from("test");
720 let peer_id = vec![21];
721 let (tx, rx) = mpsc::channel(1);
722 let mut node_api = MockNodeAPI::new(get_dummy_node_state());
723 node_api.set_on_stream_custom_messages(rx).await;
724
725 let transport = Arc::new(Transport::new(Arc::new(node_api)));
726 let (stop, cancel) = watch::channel(());
727 transport.start(cancel);
728 let mut stream = transport
729 .stream_notifications::<Notification>(method, peer_id.clone())
730 .await
731 .unwrap();
732 let payload = RpcServerMessage {
733 jsonrpc: String::from("2.0"),
734 body: RpcServerMessageBody::Notification {
735 method: String::from("test"),
736 params: json!({}),
737 },
738 };
739 let raw_payload = serde_json::to_vec(&payload).unwrap();
740 tx.send(CustomMessage {
741 message_type: LSPS0_MESSAGE_TYPE,
742 payload: raw_payload,
743 peer_id: vec![22],
744 })
745 .await
746 .unwrap();
747 let a = stream.try_recv();
748 assert!(a.is_err());
749 let _ = stop.send(());
750 }
751}