]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Merge pull request #87 from savil/editorconfig
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Wed, 25 Jul 2018 22:08:38 +0000 (18:08 -0400)
committerGitHub <noreply@github.com>
Wed, 25 Jul 2018 22:08:38 +0000 (18:08 -0400)
add .editorconfig to ensure we use tabs, not spaces

fuzz/Cargo.toml
fuzz/fuzz_targets/channel_target.rs
fuzz/fuzz_targets/router_target.rs [new file with mode: 0644]
src/ln/msgs.rs
src/ln/router.rs

index 14610540510f394fec670a92dadeb0f5ad17e3d8..22f4bdcc181a3706ddbaf6260e16ea5fa026aa55 100644 (file)
@@ -42,6 +42,10 @@ path = "fuzz_targets/channel_target.rs"
 name = "full_stack_target"
 path = "fuzz_targets/full_stack_target.rs"
 
+[[bin]]
+name = "router_target"
+path = "fuzz_targets/router_target.rs"
+
 [[bin]]
 name = "chanmon_deser_target"
 path = "fuzz_targets/chanmon_deser_target.rs"
index 27891e11aa9aad2e96783659158fbc4d81aa5eac..1bcc3d708c677ecd52421264785146132d6d112e 100644 (file)
@@ -10,7 +10,7 @@ use bitcoin::network::serialize::{serialize, BitcoinHash};
 use lightning::ln::channel::{Channel, ChannelKeys};
 use lightning::ln::channelmanager::{HTLCFailReason, PendingForwardHTLCInfo};
 use lightning::ln::msgs;
-use lightning::ln::msgs::MsgDecodable;
+use lightning::ln::msgs::{MsgDecodable, ErrorAction};
 use lightning::chain::chaininterface::{FeeEstimator, ConfirmationTarget};
 use lightning::chain::transaction::OutPoint;
 use lightning::util::reset_rng_state;
@@ -120,7 +120,8 @@ pub fn do_test(data: &[u8]) {
                                        msgs::DecodeError::BadSignature => return,
                                        msgs::DecodeError::BadText => return,
                                        msgs::DecodeError::ExtraAddressesPerType => return,
-                                       msgs::DecodeError::WrongLength => panic!("We picked the length..."),
+                                       msgs::DecodeError::BadLengthDescriptor => return,
+                                       msgs::DecodeError::ShortRead => panic!("We picked the length..."),
                                }
                        }
                }
@@ -141,7 +142,8 @@ pub fn do_test(data: &[u8]) {
                                                msgs::DecodeError::BadSignature => return,
                                                msgs::DecodeError::BadText => return,
                                                msgs::DecodeError::ExtraAddressesPerType => return,
-                                               msgs::DecodeError::WrongLength => panic!("We picked the length..."),
+                                               msgs::DecodeError::BadLengthDescriptor => return,
+                                               msgs::DecodeError::ShortRead => panic!("We picked the length..."),
                                        }
                                }
                        }
@@ -237,10 +239,25 @@ pub fn do_test(data: &[u8]) {
        let funding_locked = decode_msg!(msgs::FundingLocked, 32+33);
        return_err!(channel.funding_locked(&funding_locked));
 
+       macro_rules! test_err {
+               ($expr: expr) => {
+                       match $expr {
+                               Ok(r) => Some(r),
+                               Err(e) => match e.action {
+                                       None => return,
+                                       Some(ErrorAction::UpdateFailHTLC {..}) => None,
+                                       Some(ErrorAction::DisconnectPeer {..}) => return,
+                                       Some(ErrorAction::IgnoreError) => None,
+                                       Some(ErrorAction::SendErrorMessage {..}) => None,
+                               },
+                       }
+               }
+       }
+
        loop {
                match get_slice!(1)[0] {
                        0 => {
-                               return_err!(channel.send_htlc(slice_to_be64(get_slice!(8)), [42; 32], slice_to_be32(get_slice!(4)), msgs::OnionPacket {
+                               test_err!(channel.send_htlc(slice_to_be64(get_slice!(8)), [42; 32], slice_to_be32(get_slice!(4)), msgs::OnionPacket {
                                        version: get_slice!(1)[0],
                                        public_key: get_pubkey!(),
                                        hop_data: [0; 20*65],
@@ -248,44 +265,45 @@ pub fn do_test(data: &[u8]) {
                                }));
                        },
                        1 => {
-                               return_err!(channel.send_commitment());
+                               test_err!(channel.send_commitment());
                        },
                        2 => {
                                let update_add_htlc = decode_msg!(msgs::UpdateAddHTLC, 32+8+8+32+4+4+33+20*65+32);
-                               return_err!(channel.update_add_htlc(&update_add_htlc, PendingForwardHTLCInfo::dummy()));
+                               test_err!(channel.update_add_htlc(&update_add_htlc, PendingForwardHTLCInfo::dummy()));
                        },
                        3 => {
                                let update_fulfill_htlc = decode_msg!(msgs::UpdateFulfillHTLC, 32 + 8 + 32);
-                               return_err!(channel.update_fulfill_htlc(&update_fulfill_htlc));
+                               test_err!(channel.update_fulfill_htlc(&update_fulfill_htlc));
                        },
                        4 => {
                                let update_fail_htlc = decode_msg_with_len16!(msgs::UpdateFailHTLC, 32 + 8, 1);
-                               return_err!(channel.update_fail_htlc(&update_fail_htlc, HTLCFailReason::dummy()));
+                               test_err!(channel.update_fail_htlc(&update_fail_htlc, HTLCFailReason::dummy()));
                        },
                        5 => {
                                let update_fail_malformed_htlc = decode_msg!(msgs::UpdateFailMalformedHTLC, 32+8+32+2);
-                               return_err!(channel.update_fail_malformed_htlc(&update_fail_malformed_htlc, HTLCFailReason::dummy()));
+                               test_err!(channel.update_fail_malformed_htlc(&update_fail_malformed_htlc, HTLCFailReason::dummy()));
                        },
                        6 => {
                                let commitment_signed = decode_msg_with_len16!(msgs::CommitmentSigned, 32+64, 64);
-                               return_err!(channel.commitment_signed(&commitment_signed));
+                               test_err!(channel.commitment_signed(&commitment_signed));
                        },
                        7 => {
                                let revoke_and_ack = decode_msg!(msgs::RevokeAndACK, 32+32+33);
-                               return_err!(channel.revoke_and_ack(&revoke_and_ack));
+                               test_err!(channel.revoke_and_ack(&revoke_and_ack));
                        },
                        8 => {
                                let update_fee = decode_msg!(msgs::UpdateFee, 32+4);
-                               return_err!(channel.update_fee(&fee_est, &update_fee));
+                               test_err!(channel.update_fee(&fee_est, &update_fee));
                        },
                        9 => {
                                let shutdown = decode_msg_with_len16!(msgs::Shutdown, 32, 1);
-                               return_err!(channel.shutdown(&fee_est, &shutdown));
+                               test_err!(channel.shutdown(&fee_est, &shutdown));
                                if channel.is_shutdown() { return; }
                        },
                        10 => {
                                let closing_signed = decode_msg!(msgs::ClosingSigned, 32+8+64);
-                               if return_err!(channel.closing_signed(&fee_est, &closing_signed)).1.is_some() {
+                               let sign_res = test_err!(channel.closing_signed(&fee_est, &closing_signed));
+                               if sign_res.is_some() && sign_res.unwrap().1.is_some() {
                                        assert!(channel.is_shutdown());
                                        return;
                                }
diff --git a/fuzz/fuzz_targets/router_target.rs b/fuzz/fuzz_targets/router_target.rs
new file mode 100644 (file)
index 0000000..13733ad
--- /dev/null
@@ -0,0 +1,219 @@
+extern crate bitcoin;
+extern crate lightning;
+extern crate secp256k1;
+
+use lightning::ln::channelmanager::ChannelDetails;
+use lightning::ln::msgs;
+use lightning::ln::msgs::{MsgDecodable, RoutingMessageHandler};
+use lightning::ln::router::{Router, RouteHint};
+use lightning::util::reset_rng_state;
+
+use secp256k1::key::PublicKey;
+use secp256k1::Secp256k1;
+
+#[inline]
+pub fn slice_to_be16(v: &[u8]) -> u16 {
+       ((v[0] as u16) << 8*1) |
+       ((v[1] as u16) << 8*0)
+}
+
+#[inline]
+pub fn slice_to_be32(v: &[u8]) -> u32 {
+       ((v[0] as u32) << 8*3) |
+       ((v[1] as u32) << 8*2) |
+       ((v[2] as u32) << 8*1) |
+       ((v[3] as u32) << 8*0)
+}
+
+#[inline]
+pub fn slice_to_be64(v: &[u8]) -> u64 {
+       ((v[0] as u64) << 8*7) |
+       ((v[1] as u64) << 8*6) |
+       ((v[2] as u64) << 8*5) |
+       ((v[3] as u64) << 8*4) |
+       ((v[4] as u64) << 8*3) |
+       ((v[5] as u64) << 8*2) |
+       ((v[6] as u64) << 8*1) |
+       ((v[7] as u64) << 8*0)
+}
+
+#[inline]
+pub fn do_test(data: &[u8]) {
+       reset_rng_state();
+
+       let mut read_pos = 0;
+       macro_rules! get_slice_nonadvancing {
+               ($len: expr) => {
+                       {
+                               if data.len() < read_pos + $len as usize {
+                                       return;
+                               }
+                               &data[read_pos..read_pos + $len as usize]
+                       }
+               }
+       }
+       macro_rules! get_slice {
+               ($len: expr) => {
+                       {
+                               let res = get_slice_nonadvancing!($len);
+                               read_pos += $len;
+                               res
+                       }
+               }
+       }
+
+       macro_rules! decode_msg {
+               ($MsgType: path, $len: expr) => {
+                       match <($MsgType)>::decode(get_slice!($len)) {
+                               Ok(msg) => msg,
+                               Err(e) => match e {
+                                       msgs::DecodeError::UnknownRealmByte => return,
+                                       msgs::DecodeError::BadPublicKey => return,
+                                       msgs::DecodeError::BadSignature => return,
+                                       msgs::DecodeError::BadText => return,
+                                       msgs::DecodeError::ExtraAddressesPerType => return,
+                                       msgs::DecodeError::BadLengthDescriptor => return,
+                                       msgs::DecodeError::ShortRead => panic!("We picked the length..."),
+                               }
+                       }
+               }
+       }
+
+       macro_rules! decode_msg_with_len16 {
+               ($MsgType: path, $begin_len: expr, $excess: expr) => {
+                       {
+                               let extra_len = slice_to_be16(&get_slice_nonadvancing!($begin_len as usize + 2)[$begin_len..$begin_len + 2]);
+                               decode_msg!($MsgType, $begin_len as usize + 2 + (extra_len as usize) + $excess)
+                       }
+               }
+       }
+
+       let secp_ctx = Secp256k1::new();
+       macro_rules! get_pubkey {
+               () => {
+                       match PublicKey::from_slice(&secp_ctx, get_slice!(33)) {
+                               Ok(key) => key,
+                               Err(_) => return,
+                       }
+               }
+       }
+
+       let our_pubkey = get_pubkey!();
+       let router = Router::new(our_pubkey.clone());
+
+       loop {
+               match get_slice!(1)[0] {
+                       0 => {
+                               let start_len = slice_to_be16(&get_slice_nonadvancing!(64 + 2)[64..64 + 2]) as usize;
+                               let addr_len = slice_to_be16(&get_slice_nonadvancing!(64+start_len+2 + 74)[64+start_len+2 + 72..64+start_len+2 + 74]);
+                               if addr_len > (37+1)*4 {
+                                       return;
+                               }
+                               let _ = router.handle_node_announcement(&decode_msg_with_len16!(msgs::NodeAnnouncement, 64, 288));
+                       },
+                       1 => {
+                               let _ = router.handle_channel_announcement(&decode_msg_with_len16!(msgs::ChannelAnnouncement, 64*4, 32+8+33*4));
+                       },
+                       2 => {
+                               let _ = router.handle_channel_update(&decode_msg!(msgs::ChannelUpdate, 128));
+                       },
+                       3 => {
+                               match get_slice!(1)[0] {
+                                       0 => {
+                                               router.handle_htlc_fail_channel_update(&msgs::HTLCFailChannelUpdate::ChannelUpdateMessage {msg: decode_msg!(msgs::ChannelUpdate, 128)});
+                                       },
+                                       1 => {
+                                               let short_channel_id = slice_to_be64(get_slice!(8));
+                                               router.handle_htlc_fail_channel_update(&msgs::HTLCFailChannelUpdate::ChannelClosed {short_channel_id});
+                                       },
+                                       _ => return,
+                               }
+                       },
+                       4 => {
+                               let target = get_pubkey!();
+                               let mut first_hops_vec = Vec::new();
+                               let first_hops = match get_slice!(1)[0] {
+                                       0 => None,
+                                       1 => {
+                                               let count = slice_to_be16(get_slice!(2));
+                                               for _ in 0..count {
+                                                       first_hops_vec.push(ChannelDetails {
+                                                               channel_id: [0; 32],
+                                                               short_channel_id: Some(slice_to_be64(get_slice!(8))),
+                                                               remote_network_id: get_pubkey!(),
+                                                               channel_value_satoshis: slice_to_be64(get_slice!(8)),
+                                                               user_id: 0,
+                                                       });
+                                               }
+                                               Some(&first_hops_vec[..])
+                                       },
+                                       _ => return,
+                               };
+                               let mut last_hops_vec = Vec::new();
+                               let last_hops = {
+                                       let count = slice_to_be16(get_slice!(2));
+                                       for _ in 0..count {
+                                               last_hops_vec.push(RouteHint {
+                                                       src_node_id: get_pubkey!(),
+                                                       short_channel_id: slice_to_be64(get_slice!(8)),
+                                                       fee_base_msat: slice_to_be64(get_slice!(8)),
+                                                       fee_proportional_millionths: slice_to_be32(get_slice!(4)),
+                                                       cltv_expiry_delta: slice_to_be16(get_slice!(2)),
+                                                       htlc_minimum_msat: slice_to_be64(get_slice!(8)),
+                                               });
+                                       }
+                                       &last_hops_vec[..]
+                               };
+                               let _ = router.get_route(&target, first_hops, last_hops, slice_to_be64(get_slice!(8)), slice_to_be32(get_slice!(4)));
+                       },
+                       _ => return,
+               }
+       }
+}
+
+#[cfg(feature = "afl")]
+extern crate afl;
+#[cfg(feature = "afl")]
+fn main() {
+       afl::read_stdio_bytes(|data| {
+               do_test(&data);
+       });
+}
+
+#[cfg(feature = "honggfuzz")]
+#[macro_use] extern crate honggfuzz;
+#[cfg(feature = "honggfuzz")]
+fn main() {
+       loop {
+               fuzz!(|data| {
+                       do_test(data);
+               });
+       }
+}
+
+#[cfg(test)]
+mod tests {
+       fn extend_vec_from_hex(hex: &str, out: &mut Vec<u8>) {
+               let mut b = 0;
+               for (idx, c) in hex.as_bytes().iter().enumerate() {
+                       b <<= 4;
+                       match *c {
+                               b'A'...b'F' => b |= c - b'A' + 10,
+                               b'a'...b'f' => b |= c - b'a' + 10,
+                               b'0'...b'9' => b |= c - b'0',
+                               _ => panic!("Bad hex"),
+                       }
+                       if (idx & 1) == 1 {
+                               out.push(b);
+                               b = 0;
+                       }
+               }
+       }
+
+       #[test]
+       fn duplicate_crash() {
+               let mut a = Vec::new();
+               extend_vec_from_hex("00", &mut a);
+               super::do_test(&a);
+       }
+}
index 53aa23972889e4f6734d1308b2c30a8ccc0119ad..7f502530a9db55f5177aad787380645d8d72fcf0 100644 (file)
@@ -33,10 +33,13 @@ pub enum DecodeError {
        BadSignature,
        /// Value expected to be text wasn't decodable as text
        BadText,
-       /// Buffer not of right length (either too short or too long)
-       WrongLength,
+       /// Buffer too short
+       ShortRead,
        /// node_announcement included more than one address of a given type!
        ExtraAddressesPerType,
+       /// A length descriptor in the packet didn't describe the later data correctly
+       /// (currently only generated in node_announcement)
+       BadLengthDescriptor,
 }
 pub trait MsgDecodable: Sized {
        fn decode(v: &[u8]) -> Result<Self, DecodeError>;
@@ -500,8 +503,9 @@ impl Error for DecodeError {
                        DecodeError::BadPublicKey => "Invalid public key in packet",
                        DecodeError::BadSignature => "Invalid signature in packet",
                        DecodeError::BadText => "Invalid text in packet",
-                       DecodeError::WrongLength => "Data was wrong length for packet",
+                       DecodeError::ShortRead => "Packet extended beyond the provided bytes",
                        DecodeError::ExtraAddressesPerType => "More than one address of a single type",
+                       DecodeError::BadLengthDescriptor => "A length descriptor in the packet didn't describe the later data correctly",
                }
        }
 }
@@ -537,9 +541,9 @@ macro_rules! secp_signature {
 
 impl MsgDecodable for LocalFeatures {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
-               if v.len() < 2 { return Err(DecodeError::WrongLength); }
+               if v.len() < 2 { return Err(DecodeError::ShortRead); }
                let len = byte_utils::slice_to_be16(&v[0..2]) as usize;
-               if v.len() < len + 2 { return Err(DecodeError::WrongLength); }
+               if v.len() < len + 2 { return Err(DecodeError::ShortRead); }
                let mut flags = Vec::with_capacity(len);
                flags.extend_from_slice(&v[2..2 + len]);
                Ok(Self {
@@ -559,9 +563,9 @@ impl MsgEncodable for LocalFeatures {
 
 impl MsgDecodable for GlobalFeatures {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
-               if v.len() < 2 { return Err(DecodeError::WrongLength); }
+               if v.len() < 2 { return Err(DecodeError::ShortRead); }
                let len = byte_utils::slice_to_be16(&v[0..2]) as usize;
-               if v.len() < len + 2 { return Err(DecodeError::WrongLength); }
+               if v.len() < len + 2 { return Err(DecodeError::ShortRead); }
                let mut flags = Vec::with_capacity(len);
                flags.extend_from_slice(&v[2..2 + len]);
                Ok(Self {
@@ -583,7 +587,7 @@ impl MsgDecodable for Init {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                let global_features = GlobalFeatures::decode(v)?;
                if v.len() < global_features.flags.len() + 4 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let local_features = LocalFeatures::decode(&v[global_features.flags.len() + 2..])?;
                Ok(Self {
@@ -604,12 +608,12 @@ impl MsgEncodable for Init {
 impl MsgDecodable for Ping {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 4 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let ponglen = byte_utils::slice_to_be16(&v[0..2]);
                let byteslen = byte_utils::slice_to_be16(&v[2..4]);
                if v.len() < 4 + byteslen as usize {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                Ok(Self {
                        ponglen,
@@ -629,11 +633,11 @@ impl MsgEncodable for Ping {
 impl MsgDecodable for Pong {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 2 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let byteslen = byte_utils::slice_to_be16(&v[0..2]);
                if v.len() < 2 + byteslen as usize {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                Ok(Self {
                        byteslen
@@ -652,7 +656,7 @@ impl MsgEncodable for Pong {
 impl MsgDecodable for OpenChannel {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 2*32+6*8+4+2*2+6*33+1 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let ctx = Secp256k1::without_caps();
 
@@ -660,11 +664,9 @@ impl MsgDecodable for OpenChannel {
                if v.len() >= 321 {
                        let len = byte_utils::slice_to_be16(&v[319..321]) as usize;
                        if v.len() < 321+len {
-                               return Err(DecodeError::WrongLength);
+                               return Err(DecodeError::ShortRead);
                        }
                        shutdown_scriptpubkey = Some(Script::from(v[321..321+len].to_vec()));
-               } else if v.len() != 2*32+6*8+4+2*2+6*33+1 { // Message cant have 1 extra byte
-                       return Err(DecodeError::WrongLength);
                }
 
                Ok(OpenChannel {
@@ -725,7 +727,7 @@ impl MsgEncodable for OpenChannel {
 impl MsgDecodable for AcceptChannel {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+4*8+4+2*2+6*33 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let ctx = Secp256k1::without_caps();
 
@@ -733,11 +735,9 @@ impl MsgDecodable for AcceptChannel {
                if v.len() >= 272 {
                        let len = byte_utils::slice_to_be16(&v[270..272]) as usize;
                        if v.len() < 272+len {
-                               return Err(DecodeError::WrongLength);
+                               return Err(DecodeError::ShortRead);
                        }
                        shutdown_scriptpubkey = Some(Script::from(v[272..272+len].to_vec()));
-               } else if v.len() != 32+4*8+4+2*2+6*33 { // Message cant have 1 extra byte
-                       return Err(DecodeError::WrongLength);
                }
 
                let mut temporary_channel_id = [0; 32];
@@ -792,7 +792,7 @@ impl MsgEncodable for AcceptChannel {
 impl MsgDecodable for FundingCreated {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+32+2+64 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let ctx = Secp256k1::without_caps();
                let mut temporary_channel_id = [0; 32];
@@ -820,7 +820,7 @@ impl MsgEncodable for FundingCreated {
 impl MsgDecodable for FundingSigned {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+64 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let ctx = Secp256k1::without_caps();
                let mut channel_id = [0; 32];
@@ -843,7 +843,7 @@ impl MsgEncodable for FundingSigned {
 impl MsgDecodable for FundingLocked {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+33 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let ctx = Secp256k1::without_caps();
                let mut channel_id = [0; 32];
@@ -866,11 +866,11 @@ impl MsgEncodable for FundingLocked {
 impl MsgDecodable for Shutdown {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32 + 2 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let scriptlen = byte_utils::slice_to_be16(&v[32..34]) as usize;
                if v.len() < 32 + 2 + scriptlen {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let mut channel_id = [0; 32];
                channel_id[..].copy_from_slice(&v[0..32]);
@@ -893,7 +893,7 @@ impl MsgEncodable for Shutdown {
 impl MsgDecodable for ClosingSigned {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32 + 8 + 64 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let secp_ctx = Secp256k1::without_caps();
                let mut channel_id = [0; 32];
@@ -919,7 +919,7 @@ impl MsgEncodable for ClosingSigned {
 impl MsgDecodable for UpdateAddHTLC {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+8+8+32+4+1+33+20*65+32 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let mut channel_id = [0; 32];
                channel_id[..].copy_from_slice(&v[0..32]);
@@ -951,7 +951,7 @@ impl MsgEncodable for UpdateAddHTLC {
 impl MsgDecodable for UpdateFulfillHTLC {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+8+32 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let mut channel_id = [0; 32];
                channel_id[..].copy_from_slice(&v[0..32]);
@@ -977,7 +977,7 @@ impl MsgEncodable for UpdateFulfillHTLC {
 impl MsgDecodable for UpdateFailHTLC {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+8 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let mut channel_id = [0; 32];
                channel_id[..].copy_from_slice(&v[0..32]);
@@ -1002,7 +1002,7 @@ impl MsgEncodable for UpdateFailHTLC {
 impl MsgDecodable for UpdateFailMalformedHTLC {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+8+32+2 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let mut channel_id = [0; 32];
                channel_id[..].copy_from_slice(&v[0..32]);
@@ -1030,14 +1030,14 @@ impl MsgEncodable for UpdateFailMalformedHTLC {
 impl MsgDecodable for CommitmentSigned {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+64+2 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let mut channel_id = [0; 32];
                channel_id[..].copy_from_slice(&v[0..32]);
 
                let htlcs = byte_utils::slice_to_be16(&v[96..98]) as usize;
                if v.len() < 32+64+2+htlcs*64 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let mut htlc_signatures = Vec::with_capacity(htlcs);
                let secp_ctx = Secp256k1::without_caps();
@@ -1068,7 +1068,7 @@ impl MsgEncodable for CommitmentSigned {
 impl MsgDecodable for RevokeAndACK {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+32+33 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let mut channel_id = [0; 32];
                channel_id[..].copy_from_slice(&v[0..32]);
@@ -1095,7 +1095,7 @@ impl MsgEncodable for RevokeAndACK {
 impl MsgDecodable for UpdateFee {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+4 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let mut channel_id = [0; 32];
                channel_id[..].copy_from_slice(&v[0..32]);
@@ -1117,12 +1117,12 @@ impl MsgEncodable for UpdateFee {
 impl MsgDecodable for ChannelReestablish {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+2*8+33 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
 
                let your_last_per_commitment_secret = if v.len() > 32+2*8+33 {
                        if v.len() < 32+2*8+33 + 32 {
-                               return Err(DecodeError::WrongLength);
+                               return Err(DecodeError::ShortRead);
                        }
                        let mut inner_array = [0; 32];
                        inner_array.copy_from_slice(&v[48..48+32]);
@@ -1165,7 +1165,7 @@ impl MsgEncodable for ChannelReestablish {
 impl MsgDecodable for AnnouncementSignatures {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+8+64*2 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let secp_ctx = Secp256k1::without_caps();
                let mut channel_id = [0; 32];
@@ -1194,7 +1194,7 @@ impl MsgDecodable for UnsignedNodeAnnouncement {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                let features = GlobalFeatures::decode(&v[..])?;
                if v.len() < features.encoded_len() + 4 + 33 + 3 + 32 + 2 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let start = features.encoded_len();
 
@@ -1206,22 +1206,23 @@ impl MsgDecodable for UnsignedNodeAnnouncement {
 
                let addrlen = byte_utils::slice_to_be16(&v[start + 72..start + 74]) as usize;
                if v.len() < start + 74 + addrlen {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
+               let addr_read_limit = start + 74 + addrlen;
 
                let mut addresses = Vec::with_capacity(4);
                let mut read_pos = start + 74;
                loop {
-                       if v.len() <= read_pos { break; }
+                       if addr_read_limit <= read_pos { break; }
                        match v[read_pos] {
                                0 => { read_pos += 1; },
                                1 => {
-                                       if v.len() < read_pos + 1 + 6 {
-                                               return Err(DecodeError::WrongLength);
-                                       }
                                        if addresses.len() > 0 {
                                                return Err(DecodeError::ExtraAddressesPerType);
                                        }
+                                       if addr_read_limit < read_pos + 1 + 6 {
+                                               return Err(DecodeError::BadLengthDescriptor);
+                                       }
                                        let mut addr = [0; 4];
                                        addr.copy_from_slice(&v[read_pos + 1..read_pos + 5]);
                                        addresses.push(NetAddress::IPv4 {
@@ -1231,12 +1232,12 @@ impl MsgDecodable for UnsignedNodeAnnouncement {
                                        read_pos += 1 + 6;
                                },
                                2 => {
-                                       if v.len() < read_pos + 1 + 18 {
-                                               return Err(DecodeError::WrongLength);
-                                       }
                                        if addresses.len() > 1 || (addresses.len() == 1 && addresses[0].get_id() != 1) {
                                                return Err(DecodeError::ExtraAddressesPerType);
                                        }
+                                       if addr_read_limit < read_pos + 1 + 18 {
+                                               return Err(DecodeError::BadLengthDescriptor);
+                                       }
                                        let mut addr = [0; 16];
                                        addr.copy_from_slice(&v[read_pos + 1..read_pos + 17]);
                                        addresses.push(NetAddress::IPv6 {
@@ -1246,12 +1247,12 @@ impl MsgDecodable for UnsignedNodeAnnouncement {
                                        read_pos += 1 + 18;
                                },
                                3 => {
-                                       if v.len() < read_pos + 1 + 12 {
-                                               return Err(DecodeError::WrongLength);
-                                       }
                                        if addresses.len() > 2 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 2) {
                                                return Err(DecodeError::ExtraAddressesPerType);
                                        }
+                                       if addr_read_limit < read_pos + 1 + 12 {
+                                               return Err(DecodeError::BadLengthDescriptor);
+                                       }
                                        let mut addr = [0; 10];
                                        addr.copy_from_slice(&v[read_pos + 1..read_pos + 11]);
                                        addresses.push(NetAddress::OnionV2 {
@@ -1261,12 +1262,12 @@ impl MsgDecodable for UnsignedNodeAnnouncement {
                                        read_pos += 1 + 12;
                                },
                                4 => {
-                                       if v.len() < read_pos + 1 + 37 {
-                                               return Err(DecodeError::WrongLength);
-                                       }
                                        if addresses.len() > 3 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 3) {
                                                return Err(DecodeError::ExtraAddressesPerType);
                                        }
+                                       if addr_read_limit < read_pos + 1 + 37 {
+                                               return Err(DecodeError::BadLengthDescriptor);
+                                       }
                                        let mut ed25519_pubkey = [0; 32];
                                        ed25519_pubkey.copy_from_slice(&v[read_pos + 1..read_pos + 33]);
                                        addresses.push(NetAddress::OnionV3 {
@@ -1340,7 +1341,7 @@ impl MsgEncodable for UnsignedNodeAnnouncement {
 impl MsgDecodable for NodeAnnouncement {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 64 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let secp_ctx = Secp256k1::without_caps();
                Ok(Self {
@@ -1364,7 +1365,7 @@ impl MsgDecodable for UnsignedChannelAnnouncement {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                let features = GlobalFeatures::decode(&v[..])?;
                if v.len() < features.encoded_len() + 32 + 8 + 33*4 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let start = features.encoded_len();
                let secp_ctx = Secp256k1::without_caps();
@@ -1397,7 +1398,7 @@ impl MsgEncodable for UnsignedChannelAnnouncement {
 impl MsgDecodable for ChannelAnnouncement {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 64*4 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let secp_ctx = Secp256k1::without_caps();
                Ok(Self {
@@ -1426,7 +1427,7 @@ impl MsgEncodable for ChannelAnnouncement {
 impl MsgDecodable for UnsignedChannelUpdate {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32+8+4+2+2+8+4+4 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                Ok(Self {
                        chain_hash: deserialize(&v[0..32]).unwrap(),
@@ -1458,7 +1459,7 @@ impl MsgEncodable for UnsignedChannelUpdate {
 impl MsgDecodable for ChannelUpdate {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 128 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let secp_ctx = Secp256k1::without_caps();
                Ok(Self {
@@ -1479,7 +1480,7 @@ impl MsgEncodable for ChannelUpdate {
 impl MsgDecodable for OnionRealm0HopData {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                Ok(OnionRealm0HopData {
                        short_channel_id: byte_utils::slice_to_be64(&v[0..8]),
@@ -1502,7 +1503,7 @@ impl MsgEncodable for OnionRealm0HopData {
 impl MsgDecodable for OnionHopData {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 65 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let realm = v[0];
                if realm != 0 {
@@ -1530,7 +1531,7 @@ impl MsgEncodable for OnionHopData {
 impl MsgDecodable for OnionPacket {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 1+33+20*65+32 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let mut hop_data = [0; 20*65];
                hop_data.copy_from_slice(&v[34..1334]);
@@ -1559,15 +1560,15 @@ impl MsgEncodable for OnionPacket {
 impl MsgDecodable for DecodedOnionErrorPacket {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 32 + 4 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let failuremsg_len = byte_utils::slice_to_be16(&v[32..34]) as usize;
                if v.len() < 32 + 4 + failuremsg_len {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let padding_len = byte_utils::slice_to_be16(&v[34 + failuremsg_len..]) as usize;
                if v.len() < 32 + 4 + failuremsg_len + padding_len {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
 
                let mut hmac = [0; 32];
@@ -1594,11 +1595,11 @@ impl MsgEncodable for DecodedOnionErrorPacket {
 impl MsgDecodable for OnionErrorPacket {
        fn decode(v: &[u8]) -> Result<Self, DecodeError> {
                if v.len() < 2 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let len = byte_utils::slice_to_be16(&v[0..2]) as usize;
                if v.len() < 2 + len {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                Ok(Self {
                        data: v[2..len+2].to_vec(),
@@ -1626,11 +1627,11 @@ impl MsgEncodable for ErrorMessage {
 impl MsgDecodable for ErrorMessage {
        fn decode(v: &[u8]) -> Result<Self,DecodeError> {
                if v.len() < 34 {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let len = byte_utils::slice_to_be16(&v[32..34]);
                if v.len() < 34 + len as usize {
-                       return Err(DecodeError::WrongLength);
+                       return Err(DecodeError::ShortRead);
                }
                let data = match String::from_utf8(v[34..34 + len as usize].to_vec()) {
                        Ok(s) => s,
index ad7ae675117f935561979b1ff97277dde314b48f..f30eb7912792df24f1dd6bcedae86cadb2e8f69a 100644 (file)
@@ -428,39 +428,42 @@ impl Router {
                        ( $chan_id: expr, $dest_node_id: expr, $directional_info: expr, $starting_fee_msat: expr ) => {
                                //TODO: Explore simply adding fee to hit htlc_minimum_msat
                                if $starting_fee_msat as u64 + final_value_msat > $directional_info.htlc_minimum_msat {
-                                       let new_fee = $directional_info.fee_base_msat as u64 + ($starting_fee_msat + final_value_msat) * ($directional_info.fee_proportional_millionths as u64) / 1000000;
-                                       let mut total_fee = $starting_fee_msat as u64;
-                                       let mut hm_entry = dist.entry(&$directional_info.src_node_id);
-                                       let old_entry = hm_entry.or_insert_with(|| {
-                                               let node = network.nodes.get(&$directional_info.src_node_id).unwrap();
-                                               (u64::max_value(),
-                                                       node.lowest_inbound_channel_fee_base_msat as u64,
-                                                       node.lowest_inbound_channel_fee_proportional_millionths as u64,
-                                                       RouteHop {
-                                                               pubkey: PublicKey::new(),
-                                                               short_channel_id: 0,
-                                                               fee_msat: 0,
-                                                               cltv_expiry_delta: 0,
-                                               })
-                                       });
-                                       if $directional_info.src_node_id != network.our_node_id {
-                                               // Ignore new_fee for channel-from-us as we assume all channels-from-us
-                                               // will have the same effective-fee
-                                               total_fee += new_fee;
-                                               total_fee += old_entry.2 * (final_value_msat + total_fee) / 1000000 + old_entry.1;
-                                       }
-                                       let new_graph_node = RouteGraphNode {
-                                               pubkey: $directional_info.src_node_id,
-                                               lowest_fee_to_peer_through_node: total_fee,
-                                       };
-                                       if old_entry.0 > total_fee {
-                                               targets.push(new_graph_node);
-                                               old_entry.0 = total_fee;
-                                               old_entry.3 = RouteHop {
-                                                       pubkey: $dest_node_id.clone(),
-                                                       short_channel_id: $chan_id.clone(),
-                                                       fee_msat: new_fee, // This field is ignored on the last-hop anyway
-                                                       cltv_expiry_delta: $directional_info.cltv_expiry_delta as u32,
+                                       let proportional_fee_millions = ($starting_fee_msat + final_value_msat).checked_mul($directional_info.fee_proportional_millionths as u64);
+                                       if let Some(proportional_fee) = proportional_fee_millions {
+                                               let new_fee = $directional_info.fee_base_msat as u64 + proportional_fee / 1000000;
+                                               let mut total_fee = $starting_fee_msat as u64;
+                                               let mut hm_entry = dist.entry(&$directional_info.src_node_id);
+                                               let old_entry = hm_entry.or_insert_with(|| {
+                                                       let node = network.nodes.get(&$directional_info.src_node_id).unwrap();
+                                                       (u64::max_value(),
+                                                               node.lowest_inbound_channel_fee_base_msat as u64,
+                                                               node.lowest_inbound_channel_fee_proportional_millionths as u64,
+                                                               RouteHop {
+                                                                       pubkey: PublicKey::new(),
+                                                                       short_channel_id: 0,
+                                                                       fee_msat: 0,
+                                                                       cltv_expiry_delta: 0,
+                                                       })
+                                               });
+                                               if $directional_info.src_node_id != network.our_node_id {
+                                                       // Ignore new_fee for channel-from-us as we assume all channels-from-us
+                                                       // will have the same effective-fee
+                                                       total_fee += new_fee;
+                                                       total_fee += old_entry.2 * (final_value_msat + total_fee) / 1000000 + old_entry.1;
+                                               }
+                                               let new_graph_node = RouteGraphNode {
+                                                       pubkey: $directional_info.src_node_id,
+                                                       lowest_fee_to_peer_through_node: total_fee,
+                                               };
+                                               if old_entry.0 > total_fee {
+                                                       targets.push(new_graph_node);
+                                                       old_entry.0 = total_fee;
+                                                       old_entry.3 = RouteHop {
+                                                               pubkey: $dest_node_id.clone(),
+                                                               short_channel_id: $chan_id.clone(),
+                                                               fee_msat: new_fee, // This field is ignored on the last-hop anyway
+                                                               cltv_expiry_delta: $directional_info.cltv_expiry_delta as u32,
+                                                       }
                                                }
                                        }
                                }