Add test coverage for serialization of malformed HTLCs.
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 5bee15c97247308fb5aa47f0da4e2e724338b400..da5f4394a4e258d3cdab3224214a70646f10f813 100644 (file)
@@ -111,6 +111,7 @@ use crate::ln::script::ShutdownScript;
 
 /// Information about where a received HTLC('s onion) has indicated the HTLC should go.
 #[derive(Clone)] // See Channel::revoke_and_ack for why, tl;dr: Rust bug
+#[cfg_attr(test, derive(Debug, PartialEq))]
 pub enum PendingHTLCRouting {
        /// An HTLC which should be forwarded on to another node.
        Forward {
@@ -189,7 +190,7 @@ pub enum PendingHTLCRouting {
 }
 
 /// Information used to forward or fail this HTLC that is being forwarded within a blinded path.
-#[derive(Clone, Copy, Hash, PartialEq, Eq)]
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
 pub struct BlindedForward {
        /// The `blinding_point` that was set in the inbound [`msgs::UpdateAddHTLC`], or in the inbound
        /// onion payload if we're the introduction node. Useful for calculating the next hop's
@@ -213,6 +214,7 @@ impl PendingHTLCRouting {
 /// Information about an incoming HTLC, including the [`PendingHTLCRouting`] describing where it
 /// should go next.
 #[derive(Clone)] // See Channel::revoke_and_ack for why, tl;dr: Rust bug
+#[cfg_attr(test, derive(Debug, PartialEq))]
 pub struct PendingHTLCInfo {
        /// Further routing details based on whether the HTLC is being forwarded or received.
        pub routing: PendingHTLCRouting,
@@ -267,6 +269,7 @@ pub(super) enum PendingHTLCStatus {
        Fail(HTLCFailureMsg),
 }
 
+#[cfg_attr(test, derive(Clone, Debug, PartialEq))]
 pub(super) struct PendingAddHTLCInfo {
        pub(super) forward_info: PendingHTLCInfo,
 
@@ -282,6 +285,7 @@ pub(super) struct PendingAddHTLCInfo {
        prev_user_channel_id: u128,
 }
 
+#[cfg_attr(test, derive(Clone, Debug, PartialEq))]
 pub(super) enum HTLCForwardInfo {
        AddHTLC(PendingAddHTLCInfo),
        FailHTLC {
@@ -11056,12 +11060,14 @@ mod tests {
        use crate::events::{Event, HTLCDestination, MessageSendEvent, MessageSendEventsProvider, ClosureReason};
        use crate::ln::{PaymentPreimage, PaymentHash, PaymentSecret};
        use crate::ln::ChannelId;
-       use crate::ln::channelmanager::{create_recv_pending_htlc_info, inbound_payment, PaymentId, PaymentSendFailure, RecipientOnionFields, InterceptId};
+       use crate::ln::channelmanager::{create_recv_pending_htlc_info, HTLCForwardInfo, inbound_payment, PaymentId, PaymentSendFailure, RecipientOnionFields, InterceptId};
        use crate::ln::functional_test_utils::*;
        use crate::ln::msgs::{self, ErrorAction};
        use crate::ln::msgs::ChannelMessageHandler;
+       use crate::prelude::*;
        use crate::routing::router::{PaymentParameters, RouteParameters, find_route};
        use crate::util::errors::APIError;
+       use crate::util::ser::Writeable;
        use crate::util::test_utils;
        use crate::util::config::{ChannelConfig, ChannelConfigUpdate};
        use crate::sign::EntropySource;
@@ -12336,6 +12342,63 @@ mod tests {
                        check_spends!(txn[0], funding_tx);
                }
        }
+
+       #[test]
+       fn test_malformed_forward_htlcs_ser() {
+               // Ensure that `HTLCForwardInfo::FailMalformedHTLC`s are (de)serialized properly.
+               let chanmon_cfg = create_chanmon_cfgs(1);
+               let node_cfg = create_node_cfgs(1, &chanmon_cfg);
+               let persister;
+               let chain_monitor;
+               let chanmgrs = create_node_chanmgrs(1, &node_cfg, &[None]);
+               let deserialized_chanmgr;
+               let mut nodes = create_network(1, &node_cfg, &chanmgrs);
+
+               let dummy_failed_htlc = |htlc_id| {
+                       HTLCForwardInfo::FailHTLC { htlc_id, err_packet: msgs::OnionErrorPacket { data: vec![42] }, }
+               };
+               let dummy_malformed_htlc = |htlc_id| {
+                       HTLCForwardInfo::FailMalformedHTLC { htlc_id, failure_code: 0x4000, sha256_of_onion: [0; 32] }
+               };
+
+               let dummy_htlcs_1: Vec<HTLCForwardInfo> = (1..10).map(|htlc_id| {
+                       if htlc_id % 2 == 0 {
+                               dummy_failed_htlc(htlc_id)
+                       } else {
+                               dummy_malformed_htlc(htlc_id)
+                       }
+               }).collect();
+
+               let dummy_htlcs_2: Vec<HTLCForwardInfo> = (1..10).map(|htlc_id| {
+                       if htlc_id % 2 == 1 {
+                               dummy_failed_htlc(htlc_id)
+                       } else {
+                               dummy_malformed_htlc(htlc_id)
+                       }
+               }).collect();
+
+
+               let (scid_1, scid_2) = (42, 43);
+               let mut forward_htlcs = HashMap::new();
+               forward_htlcs.insert(scid_1, dummy_htlcs_1.clone());
+               forward_htlcs.insert(scid_2, dummy_htlcs_2.clone());
+
+               let mut chanmgr_fwd_htlcs = nodes[0].node.forward_htlcs.lock().unwrap();
+               *chanmgr_fwd_htlcs = forward_htlcs.clone();
+               core::mem::drop(chanmgr_fwd_htlcs);
+
+               reload_node!(nodes[0], nodes[0].node.encode(), &[], persister, chain_monitor, deserialized_chanmgr);
+
+               let mut deserialized_fwd_htlcs = nodes[0].node.forward_htlcs.lock().unwrap();
+               for scid in [scid_1, scid_2].iter() {
+                       let deserialized_htlcs = deserialized_fwd_htlcs.remove(scid).unwrap();
+                       assert_eq!(forward_htlcs.remove(scid).unwrap(), deserialized_htlcs);
+               }
+               assert!(deserialized_fwd_htlcs.is_empty());
+               core::mem::drop(deserialized_fwd_htlcs);
+
+               expect_pending_htlcs_forwardable!(nodes[0]);
+       }
 }
 
 #[cfg(ldk_bench)]