Remove duplicate specification of features
[rust-lightning] / lightning / src / ln / peer_handler.rs
index fc6d1cbcda13883864382bb0cb07d35e6c3e1d32..2dadf2ec6149e565d65e155c6c844694fede4096 100644 (file)
@@ -355,10 +355,14 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
                                        InitSyncTracker::ChannelsSyncing(c) if c < 0xffff_ffff_ffff_ffff => {
                                                let steps = ((MSG_BUFF_SIZE - peer.pending_outbound_buffer.len() + 2) / 3) as u8;
                                                let all_messages = self.message_handler.route_handler.get_next_channel_announcements(c, steps);
-                                               for &(ref announce, ref update_a, ref update_b) in all_messages.iter() {
+                                               for &(ref announce, ref update_a_option, ref update_b_option) in all_messages.iter() {
                                                        encode_and_send_msg!(announce);
-                                                       encode_and_send_msg!(update_a);
-                                                       encode_and_send_msg!(update_b);
+                                                       if let &Some(ref update_a) = update_a_option {
+                                                               encode_and_send_msg!(update_a);
+                                                       }
+                                                       if let &Some(ref update_b) = update_b_option {
+                                                               encode_and_send_msg!(update_b);
+                                                       }
                                                        peer.sync_status = InitSyncTracker::ChannelsSyncing(announce.contents.short_channel_id + 1);
                                                }
                                                if all_messages.is_empty() || all_messages.len() != steps as usize {
@@ -544,9 +548,9 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
 
                                                                        peer.their_node_id = Some(their_node_id);
                                                                        insert_node_id!();
-                                                                       let mut features = InitFeatures::supported();
-                                                                       if self.message_handler.route_handler.should_request_full_sync(&peer.their_node_id.unwrap()) {
-                                                                               features.set_initial_routing_sync();
+                                                                       let mut features = InitFeatures::known();
+                                                                       if !self.message_handler.route_handler.should_request_full_sync(&peer.their_node_id.unwrap()) {
+                                                                               features.clear_initial_routing_sync();
                                                                        }
 
                                                                        let resp = msgs::Init { features };
@@ -638,9 +642,9 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
                                                                                                }
 
                                                                                                if !peer.outbound {
-                                                                                                       let mut features = InitFeatures::supported();
-                                                                                                       if self.message_handler.route_handler.should_request_full_sync(&peer.their_node_id.unwrap()) {
-                                                                                                               features.set_initial_routing_sync();
+                                                                                                       let mut features = InitFeatures::known();
+                                                                                                       if !self.message_handler.route_handler.should_request_full_sync(&peer.their_node_id.unwrap()) {
+                                                                                                               features.clear_initial_routing_sync();
                                                                                                        }
 
                                                                                                        let resp = msgs::Init { features };
@@ -1243,6 +1247,13 @@ mod tests {
                (fd_a.clone(), fd_b.clone())
        }
 
+       fn establish_connection_and_read_events<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler>) -> (FileDescriptor, FileDescriptor) {
+               let (mut fd_a, mut fd_b) = establish_connection(peer_a, peer_b);
+               assert_eq!(peer_b.read_event(&mut fd_b, &fd_a.outbound_data.lock().unwrap().split_off(0)).unwrap(), false);
+               assert_eq!(peer_a.read_event(&mut fd_a, &fd_b.outbound_data.lock().unwrap().split_off(0)).unwrap(), false);
+               (fd_a.clone(), fd_b.clone())
+       }
+
        #[test]
        fn test_disconnect_peer() {
                // Simple test which builds a network of PeerManager, connects and brings them to NoiseState::Finished and
@@ -1313,7 +1324,7 @@ mod tests {
                        Err(msgs::LightningError { err: "", action: msgs::ErrorAction::IgnoreError })
                }
                fn handle_htlc_fail_channel_update(&self, _update: &msgs::HTLCFailChannelUpdate) {}
-               fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(msgs::ChannelAnnouncement, msgs::ChannelUpdate,msgs::ChannelUpdate)> {
+               fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(msgs::ChannelAnnouncement, Option<msgs::ChannelUpdate>, Option<msgs::ChannelUpdate>)> {
                        let mut chan_anns = Vec::new();
                        const TOTAL_UPDS: u64 = 100;
                        let end: u64 =  min(starting_point + batch_amount as u64, TOTAL_UPDS - self.chan_anns_sent.load(Ordering::Acquire) as u64);
@@ -1322,7 +1333,7 @@ mod tests {
                                let chan_upd_2 = get_dummy_channel_update(i);
                                let chan_ann = get_dummy_channel_announcement(i);
 
-                               chan_anns.push((chan_ann, chan_upd_1, chan_upd_2));
+                               chan_anns.push((chan_ann, Some(chan_upd_1), Some(chan_upd_2)));
                        }
 
                        self.chan_anns_sent.fetch_add(chan_anns.len(), Ordering::AcqRel);
@@ -1347,7 +1358,7 @@ mod tests {
                let node_1_btckey = SecretKey::from_slice(&[40; 32]).unwrap();
                let node_2_btckey = SecretKey::from_slice(&[39; 32]).unwrap();
                let unsigned_ann = msgs::UnsignedChannelAnnouncement {
-                       features: ChannelFeatures::supported(),
+                       features: ChannelFeatures::known(),
                        chain_hash: genesis_block(network).header.bitcoin_hash(),
                        short_channel_id: short_chan_id,
                        node_id_1: PublicKey::from_secret_key(&secp_ctx, &node_1_privkey),
@@ -1417,4 +1428,47 @@ mod tests {
                assert_eq!(routing_handlers_concrete[1].clone().chan_upds_recvd.load(Ordering::Acquire), 100);
                assert_eq!(routing_handlers_concrete[1].clone().chan_anns_recvd.load(Ordering::Acquire), 50);
        }
+
+       #[test]
+       fn limit_initial_routing_sync_requests() {
+               // Inbound peer 0 requests initial_routing_sync, but outbound peer 1 does not.
+               {
+                       let chan_handlers = create_chan_handlers(2);
+                       let routing_handlers: Vec<Arc<msgs::RoutingMessageHandler>> = vec![
+                               Arc::new(test_utils::TestRoutingMessageHandler::new().set_request_full_sync()),
+                               Arc::new(test_utils::TestRoutingMessageHandler::new()),
+                       ];
+                       let peers = create_network(2, &chan_handlers, Some(&routing_handlers));
+                       let (fd_0_to_1, fd_1_to_0) = establish_connection_and_read_events(&peers[0], &peers[1]);
+
+                       let peer_0 = peers[0].peers.lock().unwrap();
+                       let peer_1 = peers[1].peers.lock().unwrap();
+
+                       let peer_0_features = peer_1.peers.get(&fd_1_to_0).unwrap().their_features.as_ref();
+                       let peer_1_features = peer_0.peers.get(&fd_0_to_1).unwrap().their_features.as_ref();
+
+                       assert!(peer_0_features.unwrap().initial_routing_sync());
+                       assert!(!peer_1_features.unwrap().initial_routing_sync());
+               }
+
+               // Outbound peer 1 requests initial_routing_sync, but inbound peer 0 does not.
+               {
+                       let chan_handlers = create_chan_handlers(2);
+                       let routing_handlers: Vec<Arc<msgs::RoutingMessageHandler>> = vec![
+                               Arc::new(test_utils::TestRoutingMessageHandler::new()),
+                               Arc::new(test_utils::TestRoutingMessageHandler::new().set_request_full_sync()),
+                       ];
+                       let peers = create_network(2, &chan_handlers, Some(&routing_handlers));
+                       let (fd_0_to_1, fd_1_to_0) = establish_connection_and_read_events(&peers[0], &peers[1]);
+
+                       let peer_0 = peers[0].peers.lock().unwrap();
+                       let peer_1 = peers[1].peers.lock().unwrap();
+
+                       let peer_0_features = peer_1.peers.get(&fd_1_to_0).unwrap().their_features.as_ref();
+                       let peer_1_features = peer_0.peers.get(&fd_0_to_1).unwrap().their_features.as_ref();
+
+                       assert!(!peer_0_features.unwrap().initial_routing_sync());
+                       assert!(peer_1_features.unwrap().initial_routing_sync());
+               }
+       }
 }