Merge pull request #2045 from wpaulino/fix-broken-commitment-test-vectors
[rust-lightning] / lightning / src / ln / onion_utils.rs
index a44e9b37d04a260f85c7de68e4da7c97f9c12e05..2916829f259ec58ee110b95fdf14fb16a702a401 100644 (file)
@@ -15,7 +15,7 @@ use crate::routing::gossip::NetworkUpdate;
 use crate::routing::router::RouteHop;
 use crate::util::chacha20::{ChaCha20, ChaChaReader};
 use crate::util::errors::{self, APIError};
-use crate::util::ser::{Readable, ReadableArgs, Writeable, LengthCalculatingWriter};
+use crate::util::ser::{Readable, ReadableArgs, Writeable, Writer, LengthCalculatingWriter};
 use crate::util::logger::Logger;
 
 use bitcoin::hashes::{Hash, HashEngine};
@@ -182,11 +182,11 @@ pub(super) fn build_onion_payloads(path: &Vec<RouteHop>, total_msat: u64, paymen
                });
                cur_value_msat += hop.fee_msat;
                if cur_value_msat >= 21000000 * 100000000 * 1000 {
-                       return Err(APIError::RouteError{err: "Channel fees overflowed?"});
+                       return Err(APIError::InvalidRoute{err: "Channel fees overflowed?".to_owned()});
                }
                cur_cltv += hop.cltv_expiry_delta as u32;
                if cur_cltv >= 500000000 {
-                       return Err(APIError::RouteError{err: "Channel CLTV overflowed?"});
+                       return Err(APIError::InvalidRoute{err: "Channel CLTV overflowed?".to_owned()});
                }
                last_short_channel_id = hop.short_channel_id;
        }
@@ -382,7 +382,7 @@ pub(super) fn build_failure_packet(shared_secret: &[u8], failure_type: u16, fail
        packet
 }
 
-#[inline]
+#[cfg(test)]
 pub(super) fn build_first_hop_failure_packet(shared_secret: &[u8], failure_type: u16, failure_data: &[u8]) -> msgs::OnionErrorPacket {
        let failure_packet = build_failure_packet(shared_secret, failure_type, failure_data);
        encrypt_failure_packet(shared_secret, &failure_packet.encode()[..])
@@ -593,7 +593,10 @@ pub(super) fn process_onion_failure<T: secp256k1::Signing, L: Deref>(secp_ctx: &
 }
 
 #[derive(Clone)] // See Channel::revoke_and_ack for why, tl;dr: Rust bug
-pub(super) enum HTLCFailReason {
+pub(super) struct HTLCFailReason(HTLCFailReasonRepr);
+
+#[derive(Clone)] // See Channel::revoke_and_ack for why, tl;dr: Rust bug
+enum HTLCFailReasonRepr {
        LightningError {
                err: msgs::OnionErrorPacket,
        },
@@ -605,18 +608,29 @@ pub(super) enum HTLCFailReason {
 
 impl core::fmt::Debug for HTLCFailReason {
        fn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> {
-               match self {
-                       HTLCFailReason::Reason { ref failure_code, .. } => {
+               match self.0 {
+                       HTLCFailReasonRepr::Reason { ref failure_code, .. } => {
                                write!(f, "HTLC error code {}", failure_code)
                        },
-                       HTLCFailReason::LightningError { .. } => {
+                       HTLCFailReasonRepr::LightningError { .. } => {
                                write!(f, "pre-built LightningError")
                        }
                }
        }
 }
 
-impl_writeable_tlv_based_enum!(HTLCFailReason,
+impl Writeable for HTLCFailReason {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), crate::io::Error> {
+               self.0.write(writer)
+       }
+}
+impl Readable for HTLCFailReason {
+       fn read<R: Read>(reader: &mut R) -> Result<Self, msgs::DecodeError> {
+               Ok(Self(Readable::read(reader)?))
+       }
+}
+
+impl_writeable_tlv_based_enum!(HTLCFailReasonRepr,
        (0, LightningError) => {
                (0, err, required),
        },
@@ -628,21 +642,59 @@ impl_writeable_tlv_based_enum!(HTLCFailReason,
 
 impl HTLCFailReason {
        pub(super) fn reason(failure_code: u16, data: Vec<u8>) -> Self {
-               Self::Reason { failure_code, data }
+               const BADONION: u16 = 0x8000;
+               const PERM: u16 = 0x4000;
+               const NODE: u16 = 0x2000;
+               const UPDATE: u16 = 0x1000;
+
+                    if failure_code == 1  | PERM { debug_assert!(data.is_empty()) }
+               else if failure_code == 2  | NODE { debug_assert!(data.is_empty()) }
+               else if failure_code == 2  | PERM | NODE { debug_assert!(data.is_empty()) }
+               else if failure_code == 3  | PERM | NODE { debug_assert!(data.is_empty()) }
+               else if failure_code == 4  | BADONION | PERM { debug_assert_eq!(data.len(), 32) }
+               else if failure_code == 5  | BADONION | PERM { debug_assert_eq!(data.len(), 32) }
+               else if failure_code == 6  | BADONION | PERM { debug_assert_eq!(data.len(), 32) }
+               else if failure_code == 7  | UPDATE {
+                       debug_assert_eq!(data.len() - 2, u16::from_be_bytes(data[0..2].try_into().unwrap()) as usize) }
+               else if failure_code == 8  | PERM { debug_assert!(data.is_empty()) }
+               else if failure_code == 9  | PERM { debug_assert!(data.is_empty()) }
+               else if failure_code == 10 | PERM { debug_assert!(data.is_empty()) }
+               else if failure_code == 11 | UPDATE {
+                       debug_assert_eq!(data.len() - 2 - 8, u16::from_be_bytes(data[8..10].try_into().unwrap()) as usize) }
+               else if failure_code == 12 | UPDATE {
+                       debug_assert_eq!(data.len() - 2 - 8, u16::from_be_bytes(data[8..10].try_into().unwrap()) as usize) }
+               else if failure_code == 13 | UPDATE {
+                       debug_assert_eq!(data.len() - 2 - 4, u16::from_be_bytes(data[4..6].try_into().unwrap()) as usize) }
+               else if failure_code == 14 | UPDATE {
+                       debug_assert_eq!(data.len() - 2, u16::from_be_bytes(data[0..2].try_into().unwrap()) as usize) }
+               else if failure_code == 15 | PERM { debug_assert_eq!(data.len(), 12) }
+               else if failure_code == 18 { debug_assert_eq!(data.len(), 4) }
+               else if failure_code == 19 { debug_assert_eq!(data.len(), 8) }
+               else if failure_code == 20 | UPDATE {
+                       debug_assert_eq!(data.len() - 2 - 2, u16::from_be_bytes(data[2..4].try_into().unwrap()) as usize) }
+               else if failure_code == 21 { debug_assert!(data.is_empty()) }
+               else if failure_code == 22 | PERM { debug_assert!(data.len() <= 11) }
+               else if failure_code == 23 { debug_assert!(data.is_empty()) }
+               else if failure_code & BADONION != 0 {
+                       // We set some bogus BADONION failure codes in test, so ignore unknown ones.
+               }
+               else { debug_assert!(false, "Unknown failure code: {}", failure_code) }
+
+               Self(HTLCFailReasonRepr::Reason { failure_code, data })
        }
 
        pub(super) fn from_failure_code(failure_code: u16) -> Self {
-               Self::Reason { failure_code, data: Vec::new() }
+               Self::reason(failure_code, Vec::new())
        }
 
        pub(super) fn from_msg(msg: &msgs::UpdateFailHTLC) -> Self {
-               Self::LightningError { err: msg.reason.clone() }
+               Self(HTLCFailReasonRepr::LightningError { err: msg.reason.clone() })
        }
 
        pub(super) fn get_encrypted_failure_packet(&self, incoming_packet_shared_secret: &[u8; 32], phantom_shared_secret: &Option<[u8; 32]>)
        -> msgs::OnionErrorPacket {
-               match self {
-                       HTLCFailReason::Reason { ref failure_code, ref data } => {
+               match self.0 {
+                       HTLCFailReasonRepr::Reason { ref failure_code, ref data } => {
                                if let Some(phantom_ss) = phantom_shared_secret {
                                        let phantom_packet = build_failure_packet(phantom_ss, *failure_code, &data[..]).encode();
                                        let encrypted_phantom_packet = encrypt_failure_packet(phantom_ss, &phantom_packet);
@@ -652,7 +704,7 @@ impl HTLCFailReason {
                                        encrypt_failure_packet(incoming_packet_shared_secret, &packet)
                                }
                        },
-                       HTLCFailReason::LightningError { err } => {
+                       HTLCFailReasonRepr::LightningError { ref err } => {
                                encrypt_failure_packet(incoming_packet_shared_secret, &err.data)
                        }
                }
@@ -662,11 +714,11 @@ impl HTLCFailReason {
                &self, secp_ctx: &Secp256k1<T>, logger: &L, htlc_source: &HTLCSource
        ) -> (Option<NetworkUpdate>, Option<u64>, bool, Option<u16>, Option<Vec<u8>>)
        where L::Target: Logger {
-               match self {
-                       HTLCFailReason::LightningError { ref err } => {
+               match self.0 {
+                       HTLCFailReasonRepr::LightningError { ref err } => {
                                process_onion_failure(secp_ctx, logger, &htlc_source, err.data.clone())
                        },
-                       HTLCFailReason::Reason { ref failure_code, ref data, .. } => {
+                       HTLCFailReasonRepr::Reason { ref failure_code, ref data, .. } => {
                                // we get a fail_malformed_htlc from the first hop
                                // TODO: We'd like to generate a NetworkUpdate for temporary
                                // failures here, but that would be insufficient as find_route