Merge pull request #1799 from TheBlueMatt/2022-10-heap-nerdsnipe
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Wed, 25 Jan 2023 23:19:13 +0000 (23:19 +0000)
committerGitHub <noreply@github.com>
Wed, 25 Jan 2023 23:19:13 +0000 (23:19 +0000)
Router Optimizations

.github/workflows/build.yml
fuzz/src/bin/gen_target.sh
fuzz/src/bin/indexedmap_target.rs [new file with mode: 0644]
fuzz/src/bin/msg_channel_details_target.rs [new file with mode: 0644]
fuzz/src/indexedmap.rs [new file with mode: 0644]
fuzz/src/lib.rs
fuzz/targets.h
lightning/src/routing/gossip.rs
lightning/src/routing/router.rs
lightning/src/util/indexed_map.rs [new file with mode: 0644]
lightning/src/util/mod.rs

index f27a2ccf863fa0348295a3f0c56810fa48280eb7..340b7f898d9ded31ba5b035e71fcf0ca83924419 100644 (file)
@@ -242,19 +242,19 @@ jobs:
         id: cache-graph
         uses: actions/cache@v3
         with:
-          path: lightning/net_graph-2021-05-31.bin
-          key: ldk-net_graph-v0.0.15-2021-05-31.bin
+          path: lightning/net_graph-2023-01-18.bin
+          key: ldk-net_graph-v0.0.113-2023-01-18.bin
       - name: Fetch routing graph snapshot
         if: steps.cache-graph.outputs.cache-hit != 'true'
         run: |
-          curl --verbose -L -o lightning/net_graph-2021-05-31.bin https://bitcoin.ninja/ldk-net_graph-v0.0.15-2021-05-31.bin
-          echo "Sha sum: $(sha256sum lightning/net_graph-2021-05-31.bin | awk '{ print $1 }')"
-          if [ "$(sha256sum lightning/net_graph-2021-05-31.bin | awk '{ print $1 }')" != "${EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM}" ]; then
+          curl --verbose -L -o lightning/net_graph-2023-01-18.bin https://bitcoin.ninja/ldk-net_graph-v0.0.113-2023-01-18.bin
+          echo "Sha sum: $(sha256sum lightning/net_graph-2023-01-18.bin | awk '{ print $1 }')"
+          if [ "$(sha256sum lightning/net_graph-2023-01-18.bin | awk '{ print $1 }')" != "${EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM}" ]; then
             echo "Bad hash"
             exit 1
           fi
         env:
-          EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM: 05a5361278f68ee2afd086cc04a1f927a63924be451f3221d380533acfacc303
+          EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM: da6066f2bddcddbe7d8a6debbd53545697137b310bbb8c4911bc8c81fc5ff48c
       - name: Fetch rapid graph sync reference input
         run: |
           curl --verbose -L -o lightning-rapid-gossip-sync/res/full_graph.lngossip https://bitcoin.ninja/ldk-compressed_graph-285cb27df79-2022-07-21.bin
index 95e65695eb868c798aba81b1d3684ba870154521..fa29540f96b35cae0ad5e43e1b8ef485d620520b 100755 (executable)
@@ -14,6 +14,7 @@ GEN_TEST peer_crypt
 GEN_TEST process_network_graph
 GEN_TEST router
 GEN_TEST zbase32
+GEN_TEST indexedmap
 
 GEN_TEST msg_accept_channel msg_targets::
 GEN_TEST msg_announcement_signatures msg_targets::
diff --git a/fuzz/src/bin/indexedmap_target.rs b/fuzz/src/bin/indexedmap_target.rs
new file mode 100644 (file)
index 0000000..238566d
--- /dev/null
@@ -0,0 +1,113 @@
+// This file is Copyright its original authors, visible in version control
+// history.
+//
+// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
+// or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
+// You may not use this file except in accordance with one or both of these
+// licenses.
+
+// This file is auto-generated by gen_target.sh based on target_template.txt
+// To modify it, modify target_template.txt and run gen_target.sh instead.
+
+#![cfg_attr(feature = "libfuzzer_fuzz", no_main)]
+
+#[cfg(not(fuzzing))]
+compile_error!("Fuzz targets need cfg=fuzzing");
+
+extern crate lightning_fuzz;
+use lightning_fuzz::indexedmap::*;
+
+#[cfg(feature = "afl")]
+#[macro_use] extern crate afl;
+#[cfg(feature = "afl")]
+fn main() {
+       fuzz!(|data| {
+               indexedmap_run(data.as_ptr(), data.len());
+       });
+}
+
+#[cfg(feature = "honggfuzz")]
+#[macro_use] extern crate honggfuzz;
+#[cfg(feature = "honggfuzz")]
+fn main() {
+       loop {
+               fuzz!(|data| {
+                       indexedmap_run(data.as_ptr(), data.len());
+               });
+       }
+}
+
+#[cfg(feature = "libfuzzer_fuzz")]
+#[macro_use] extern crate libfuzzer_sys;
+#[cfg(feature = "libfuzzer_fuzz")]
+fuzz_target!(|data: &[u8]| {
+       indexedmap_run(data.as_ptr(), data.len());
+});
+
+#[cfg(feature = "stdin_fuzz")]
+fn main() {
+       use std::io::Read;
+
+       let mut data = Vec::with_capacity(8192);
+       std::io::stdin().read_to_end(&mut data).unwrap();
+       indexedmap_run(data.as_ptr(), data.len());
+}
+
+#[test]
+fn run_test_cases() {
+       use std::fs;
+       use std::io::Read;
+       use lightning_fuzz::utils::test_logger::StringBuffer;
+
+       use std::sync::{atomic, Arc};
+       {
+               let data: Vec<u8> = vec![0];
+               indexedmap_run(data.as_ptr(), data.len());
+       }
+       let mut threads = Vec::new();
+       let threads_running = Arc::new(atomic::AtomicUsize::new(0));
+       if let Ok(tests) = fs::read_dir("test_cases/indexedmap") {
+               for test in tests {
+                       let mut data: Vec<u8> = Vec::new();
+                       let path = test.unwrap().path();
+                       fs::File::open(&path).unwrap().read_to_end(&mut data).unwrap();
+                       threads_running.fetch_add(1, atomic::Ordering::AcqRel);
+
+                       let thread_count_ref = Arc::clone(&threads_running);
+                       let main_thread_ref = std::thread::current();
+                       threads.push((path.file_name().unwrap().to_str().unwrap().to_string(),
+                               std::thread::spawn(move || {
+                                       let string_logger = StringBuffer::new();
+
+                                       let panic_logger = string_logger.clone();
+                                       let res = if ::std::panic::catch_unwind(move || {
+                                               indexedmap_test(&data, panic_logger);
+                                       }).is_err() {
+                                               Some(string_logger.into_string())
+                                       } else { None };
+                                       thread_count_ref.fetch_sub(1, atomic::Ordering::AcqRel);
+                                       main_thread_ref.unpark();
+                                       res
+                               })
+                       ));
+                       while threads_running.load(atomic::Ordering::Acquire) > 32 {
+                               std::thread::park();
+                       }
+               }
+       }
+       let mut failed_outputs = Vec::new();
+       for (test, thread) in threads.drain(..) {
+               if let Some(output) = thread.join().unwrap() {
+                       println!("\nOutput of {}:\n{}\n", test, output);
+                       failed_outputs.push(test);
+               }
+       }
+       if !failed_outputs.is_empty() {
+               println!("Test cases which failed: ");
+               for case in failed_outputs {
+                       println!("{}", case);
+               }
+               panic!();
+       }
+}
diff --git a/fuzz/src/bin/msg_channel_details_target.rs b/fuzz/src/bin/msg_channel_details_target.rs
new file mode 100644 (file)
index 0000000..cb5021a
--- /dev/null
@@ -0,0 +1,113 @@
+// This file is Copyright its original authors, visible in version control
+// history.
+//
+// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
+// or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
+// You may not use this file except in accordance with one or both of these
+// licenses.
+
+// This file is auto-generated by gen_target.sh based on target_template.txt
+// To modify it, modify target_template.txt and run gen_target.sh instead.
+
+#![cfg_attr(feature = "libfuzzer_fuzz", no_main)]
+
+#[cfg(not(fuzzing))]
+compile_error!("Fuzz targets need cfg=fuzzing");
+
+extern crate lightning_fuzz;
+use lightning_fuzz::msg_targets::msg_channel_details::*;
+
+#[cfg(feature = "afl")]
+#[macro_use] extern crate afl;
+#[cfg(feature = "afl")]
+fn main() {
+       fuzz!(|data| {
+               msg_channel_details_run(data.as_ptr(), data.len());
+       });
+}
+
+#[cfg(feature = "honggfuzz")]
+#[macro_use] extern crate honggfuzz;
+#[cfg(feature = "honggfuzz")]
+fn main() {
+       loop {
+               fuzz!(|data| {
+                       msg_channel_details_run(data.as_ptr(), data.len());
+               });
+       }
+}
+
+#[cfg(feature = "libfuzzer_fuzz")]
+#[macro_use] extern crate libfuzzer_sys;
+#[cfg(feature = "libfuzzer_fuzz")]
+fuzz_target!(|data: &[u8]| {
+       msg_channel_details_run(data.as_ptr(), data.len());
+});
+
+#[cfg(feature = "stdin_fuzz")]
+fn main() {
+       use std::io::Read;
+
+       let mut data = Vec::with_capacity(8192);
+       std::io::stdin().read_to_end(&mut data).unwrap();
+       msg_channel_details_run(data.as_ptr(), data.len());
+}
+
+#[test]
+fn run_test_cases() {
+       use std::fs;
+       use std::io::Read;
+       use lightning_fuzz::utils::test_logger::StringBuffer;
+
+       use std::sync::{atomic, Arc};
+       {
+               let data: Vec<u8> = vec![0];
+               msg_channel_details_run(data.as_ptr(), data.len());
+       }
+       let mut threads = Vec::new();
+       let threads_running = Arc::new(atomic::AtomicUsize::new(0));
+       if let Ok(tests) = fs::read_dir("test_cases/msg_channel_details") {
+               for test in tests {
+                       let mut data: Vec<u8> = Vec::new();
+                       let path = test.unwrap().path();
+                       fs::File::open(&path).unwrap().read_to_end(&mut data).unwrap();
+                       threads_running.fetch_add(1, atomic::Ordering::AcqRel);
+
+                       let thread_count_ref = Arc::clone(&threads_running);
+                       let main_thread_ref = std::thread::current();
+                       threads.push((path.file_name().unwrap().to_str().unwrap().to_string(),
+                               std::thread::spawn(move || {
+                                       let string_logger = StringBuffer::new();
+
+                                       let panic_logger = string_logger.clone();
+                                       let res = if ::std::panic::catch_unwind(move || {
+                                               msg_channel_details_test(&data, panic_logger);
+                                       }).is_err() {
+                                               Some(string_logger.into_string())
+                                       } else { None };
+                                       thread_count_ref.fetch_sub(1, atomic::Ordering::AcqRel);
+                                       main_thread_ref.unpark();
+                                       res
+                               })
+                       ));
+                       while threads_running.load(atomic::Ordering::Acquire) > 32 {
+                               std::thread::park();
+                       }
+               }
+       }
+       let mut failed_outputs = Vec::new();
+       for (test, thread) in threads.drain(..) {
+               if let Some(output) = thread.join().unwrap() {
+                       println!("\nOutput of {}:\n{}\n", test, output);
+                       failed_outputs.push(test);
+               }
+       }
+       if !failed_outputs.is_empty() {
+               println!("Test cases which failed: ");
+               for case in failed_outputs {
+                       println!("{}", case);
+               }
+               panic!();
+       }
+}
diff --git a/fuzz/src/indexedmap.rs b/fuzz/src/indexedmap.rs
new file mode 100644 (file)
index 0000000..795d617
--- /dev/null
@@ -0,0 +1,137 @@
+// This file is Copyright its original authors, visible in version control
+// history.
+//
+// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
+// or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
+// You may not use this file except in accordance with one or both of these
+// licenses.
+
+use lightning::util::indexed_map::{IndexedMap, self};
+use std::collections::{BTreeMap, btree_map};
+use hashbrown::HashSet;
+
+use crate::utils::test_logger;
+
+fn check_eq(btree: &BTreeMap<u8, u8>, indexed: &IndexedMap<u8, u8>) {
+       assert_eq!(btree.len(), indexed.len());
+       assert_eq!(btree.is_empty(), indexed.is_empty());
+
+       let mut btree_clone = btree.clone();
+       assert!(btree_clone == *btree);
+       let mut indexed_clone = indexed.clone();
+       assert!(indexed_clone == *indexed);
+
+       for k in 0..=255 {
+               assert_eq!(btree.contains_key(&k), indexed.contains_key(&k));
+               assert_eq!(btree.get(&k), indexed.get(&k));
+
+               let btree_entry = btree_clone.entry(k);
+               let indexed_entry = indexed_clone.entry(k);
+               match btree_entry {
+                       btree_map::Entry::Occupied(mut bo) => {
+                               if let indexed_map::Entry::Occupied(mut io) = indexed_entry {
+                                       assert_eq!(bo.get(), io.get());
+                                       assert_eq!(bo.get_mut(), io.get_mut());
+                               } else { panic!(); }
+                       },
+                       btree_map::Entry::Vacant(_) => {
+                               if let indexed_map::Entry::Vacant(_) = indexed_entry {
+                               } else { panic!(); }
+                       }
+               }
+       }
+
+       const STRIDE: u8 = 16;
+       for k in 0..=255/STRIDE {
+               let lower_bound = k * STRIDE;
+               let upper_bound = lower_bound + (STRIDE - 1);
+               let mut btree_iter = btree.range(lower_bound..=upper_bound);
+               let mut indexed_iter = indexed.range(lower_bound..=upper_bound);
+               loop {
+                       let b_v = btree_iter.next();
+                       let i_v = indexed_iter.next();
+                       assert_eq!(b_v, i_v);
+                       if b_v.is_none() { break; }
+               }
+       }
+
+       let mut key_set = HashSet::with_capacity(256);
+       for k in indexed.unordered_keys() {
+               assert!(key_set.insert(*k));
+               assert!(btree.contains_key(k));
+       }
+       assert_eq!(key_set.len(), btree.len());
+
+       key_set.clear();
+       for (k, v) in indexed.unordered_iter() {
+               assert!(key_set.insert(*k));
+               assert_eq!(btree.get(k).unwrap(), v);
+       }
+       assert_eq!(key_set.len(), btree.len());
+
+       key_set.clear();
+       for (k, v) in indexed_clone.unordered_iter_mut() {
+               assert!(key_set.insert(*k));
+               assert_eq!(btree.get(k).unwrap(), v);
+       }
+       assert_eq!(key_set.len(), btree.len());
+}
+
+#[inline]
+pub fn do_test(data: &[u8]) {
+       if data.len() % 2 != 0 { return; }
+       let mut btree = BTreeMap::new();
+       let mut indexed = IndexedMap::new();
+
+       // Read in k-v pairs from the input and insert them into the maps then check that the maps are
+       // equivalent in every way we can read them.
+       for tuple in data.windows(2) {
+               let prev_value_b = btree.insert(tuple[0], tuple[1]);
+               let prev_value_i = indexed.insert(tuple[0], tuple[1]);
+               assert_eq!(prev_value_b, prev_value_i);
+       }
+       check_eq(&btree, &indexed);
+
+       // Now, modify the maps in all the ways we have to do so, checking that the maps remain
+       // equivalent as we go.
+       for (k, v) in indexed.unordered_iter_mut() {
+               *v = *k;
+               *btree.get_mut(k).unwrap() = *k;
+       }
+       check_eq(&btree, &indexed);
+
+       for k in 0..=255 {
+               match btree.entry(k) {
+                       btree_map::Entry::Occupied(mut bo) => {
+                               if let indexed_map::Entry::Occupied(mut io) = indexed.entry(k) {
+                                       if k < 64 {
+                                               *io.get_mut() ^= 0xff;
+                                               *bo.get_mut() ^= 0xff;
+                                       } else if k < 128 {
+                                               *io.into_mut() ^= 0xff;
+                                               *bo.get_mut() ^= 0xff;
+                                       } else {
+                                               assert_eq!(bo.remove_entry(), io.remove_entry());
+                                       }
+                               } else { panic!(); }
+                       },
+                       btree_map::Entry::Vacant(bv) => {
+                               if let indexed_map::Entry::Vacant(iv) = indexed.entry(k) {
+                                       bv.insert(k);
+                                       iv.insert(k);
+                               } else { panic!(); }
+                       },
+               }
+       }
+       check_eq(&btree, &indexed);
+}
+
+pub fn indexedmap_test<Out: test_logger::Output>(data: &[u8], _out: Out) {
+       do_test(data);
+}
+
+#[no_mangle]
+pub extern "C" fn indexedmap_run(data: *const u8, datalen: usize) {
+       do_test(unsafe { std::slice::from_raw_parts(data, datalen) });
+}
index 2238a9702a9563a78e59ef8010f6e816f282b4b1..462307d55b42a6e977276065825cc325c8c3c807 100644 (file)
@@ -17,6 +17,7 @@ pub mod utils;
 pub mod chanmon_deser;
 pub mod chanmon_consistency;
 pub mod full_stack;
+pub mod indexedmap;
 pub mod onion_message;
 pub mod peer_crypt;
 pub mod process_network_graph;
index cff3f9bdbb52dd87e5f2cfbd68525768e109963e..5bfee07dafbb149e4db2a8262a22209355ec11ba 100644 (file)
@@ -7,6 +7,7 @@ void peer_crypt_run(const unsigned char* data, size_t data_len);
 void process_network_graph_run(const unsigned char* data, size_t data_len);
 void router_run(const unsigned char* data, size_t data_len);
 void zbase32_run(const unsigned char* data, size_t data_len);
+void indexedmap_run(const unsigned char* data, size_t data_len);
 void msg_accept_channel_run(const unsigned char* data, size_t data_len);
 void msg_announcement_signatures_run(const unsigned char* data, size_t data_len);
 void msg_channel_reestablish_run(const unsigned char* data, size_t data_len);
index 39f07914a61a11541b055fe9cabafc6a7dcce43d..950782d46f665bae2e0dd3baff42633456e3ec2d 100644 (file)
@@ -32,11 +32,11 @@ use crate::util::logger::{Logger, Level};
 use crate::util::events::{MessageSendEvent, MessageSendEventsProvider};
 use crate::util::scid_utils::{block_from_scid, scid_from_parts, MAX_SCID_BLOCK};
 use crate::util::string::PrintableString;
+use crate::util::indexed_map::{IndexedMap, Entry as IndexedMapEntry};
 
 use crate::io;
 use crate::io_extras::{copy, sink};
 use crate::prelude::*;
-use alloc::collections::{BTreeMap, btree_map::Entry as BtreeEntry};
 use core::{cmp, fmt};
 use crate::sync::{RwLock, RwLockReadGuard};
 #[cfg(feature = "std")]
@@ -133,8 +133,8 @@ pub struct NetworkGraph<L: Deref> where L::Target: Logger {
        genesis_hash: BlockHash,
        logger: L,
        // Lock order: channels -> nodes
-       channels: RwLock<BTreeMap<u64, ChannelInfo>>,
-       nodes: RwLock<BTreeMap<NodeId, NodeInfo>>,
+       channels: RwLock<IndexedMap<u64, ChannelInfo>>,
+       nodes: RwLock<IndexedMap<NodeId, NodeInfo>>,
        // Lock order: removed_channels -> removed_nodes
        //
        // NOTE: In the following `removed_*` maps, we use seconds since UNIX epoch to track time instead
@@ -158,8 +158,8 @@ pub struct NetworkGraph<L: Deref> where L::Target: Logger {
 
 /// A read-only view of [`NetworkGraph`].
 pub struct ReadOnlyNetworkGraph<'a> {
-       channels: RwLockReadGuard<'a, BTreeMap<u64, ChannelInfo>>,
-       nodes: RwLockReadGuard<'a, BTreeMap<NodeId, NodeInfo>>,
+       channels: RwLockReadGuard<'a, IndexedMap<u64, ChannelInfo>>,
+       nodes: RwLockReadGuard<'a, IndexedMap<NodeId, NodeInfo>>,
 }
 
 /// Update to the [`NetworkGraph`] based on payment failure information conveyed via the Onion
@@ -1054,10 +1054,6 @@ impl Readable for NodeAlias {
 pub struct NodeInfo {
        /// All valid channels a node has announced
        pub channels: Vec<u64>,
-       /// Lowest fees enabling routing via any of the enabled, known channels to a node.
-       /// The two fields (flat and proportional fee) are independent,
-       /// meaning they don't have to refer to the same channel.
-       pub lowest_inbound_channel_fees: Option<RoutingFees>,
        /// More information about a node from node_announcement.
        /// Optional because we store a Node entry after learning about it from
        /// a channel announcement, but before receiving a node announcement.
@@ -1066,8 +1062,8 @@ pub struct NodeInfo {
 
 impl fmt::Display for NodeInfo {
        fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
-               write!(f, "lowest_inbound_channel_fees: {:?}, channels: {:?}, announcement_info: {:?}",
-                  self.lowest_inbound_channel_fees, &self.channels[..], self.announcement_info)?;
+               write!(f, " channels: {:?}, announcement_info: {:?}",
+                       &self.channels[..], self.announcement_info)?;
                Ok(())
        }
 }
@@ -1075,7 +1071,7 @@ impl fmt::Display for NodeInfo {
 impl Writeable for NodeInfo {
        fn write<W: crate::util::ser::Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
                write_tlv_fields!(writer, {
-                       (0, self.lowest_inbound_channel_fees, option),
+                       // Note that older versions of LDK wrote the lowest inbound fees here at type 0
                        (2, self.announcement_info, option),
                        (4, self.channels, vec_type),
                });
@@ -1103,18 +1099,22 @@ impl MaybeReadable for NodeAnnouncementInfoDeserWrapper {
 
 impl Readable for NodeInfo {
        fn read<R: io::Read>(reader: &mut R) -> Result<Self, DecodeError> {
-               _init_tlv_field_var!(lowest_inbound_channel_fees, option);
+               // Historically, we tracked the lowest inbound fees for any node in order to use it as an
+               // A* heuristic when routing. Sadly, these days many, many nodes have at least one channel
+               // with zero inbound fees, causing that heuristic to provide little gain. Worse, because it
+               // requires additional complexity and lookups during routing, it ends up being a
+               // performance loss. Thus, we simply ignore the old field here and no longer track it.
+               let mut _lowest_inbound_channel_fees: Option<RoutingFees> = None;
                let mut announcement_info_wrap: Option<NodeAnnouncementInfoDeserWrapper> = None;
                _init_tlv_field_var!(channels, vec_type);
 
                read_tlv_fields!(reader, {
-                       (0, lowest_inbound_channel_fees, option),
+                       (0, _lowest_inbound_channel_fees, option),
                        (2, announcement_info_wrap, ignorable),
                        (4, channels, vec_type),
                });
 
                Ok(NodeInfo {
-                       lowest_inbound_channel_fees: _init_tlv_based_struct_field!(lowest_inbound_channel_fees, option),
                        announcement_info: announcement_info_wrap.map(|w| w.0),
                        channels: _init_tlv_based_struct_field!(channels, vec_type),
                })
@@ -1131,13 +1131,13 @@ impl<L: Deref> Writeable for NetworkGraph<L> where L::Target: Logger {
                self.genesis_hash.write(writer)?;
                let channels = self.channels.read().unwrap();
                (channels.len() as u64).write(writer)?;
-               for (ref chan_id, ref chan_info) in channels.iter() {
+               for (ref chan_id, ref chan_info) in channels.unordered_iter() {
                        (*chan_id).write(writer)?;
                        chan_info.write(writer)?;
                }
                let nodes = self.nodes.read().unwrap();
                (nodes.len() as u64).write(writer)?;
-               for (ref node_id, ref node_info) in nodes.iter() {
+               for (ref node_id, ref node_info) in nodes.unordered_iter() {
                        node_id.write(writer)?;
                        node_info.write(writer)?;
                }
@@ -1156,14 +1156,14 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
 
                let genesis_hash: BlockHash = Readable::read(reader)?;
                let channels_count: u64 = Readable::read(reader)?;
-               let mut channels = BTreeMap::new();
+               let mut channels = IndexedMap::new();
                for _ in 0..channels_count {
                        let chan_id: u64 = Readable::read(reader)?;
                        let chan_info = Readable::read(reader)?;
                        channels.insert(chan_id, chan_info);
                }
                let nodes_count: u64 = Readable::read(reader)?;
-               let mut nodes = BTreeMap::new();
+               let mut nodes = IndexedMap::new();
                for _ in 0..nodes_count {
                        let node_id = Readable::read(reader)?;
                        let node_info = Readable::read(reader)?;
@@ -1191,11 +1191,11 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
 impl<L: Deref> fmt::Display for NetworkGraph<L> where L::Target: Logger {
        fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
                writeln!(f, "Network map\n[Channels]")?;
-               for (key, val) in self.channels.read().unwrap().iter() {
+               for (key, val) in self.channels.read().unwrap().unordered_iter() {
                        writeln!(f, " {}: {}", key, val)?;
                }
                writeln!(f, "[Nodes]")?;
-               for (&node_id, val) in self.nodes.read().unwrap().iter() {
+               for (&node_id, val) in self.nodes.read().unwrap().unordered_iter() {
                        writeln!(f, " {}: {}", log_bytes!(node_id.as_slice()), val)?;
                }
                Ok(())
@@ -1218,8 +1218,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        secp_ctx: Secp256k1::verification_only(),
                        genesis_hash,
                        logger,
-                       channels: RwLock::new(BTreeMap::new()),
-                       nodes: RwLock::new(BTreeMap::new()),
+                       channels: RwLock::new(IndexedMap::new()),
+                       nodes: RwLock::new(IndexedMap::new()),
                        last_rapid_gossip_sync_timestamp: Mutex::new(None),
                        removed_channels: Mutex::new(HashMap::new()),
                        removed_nodes: Mutex::new(HashMap::new()),
@@ -1252,7 +1252,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
        /// purposes.
        #[cfg(test)]
        pub fn clear_nodes_announcement_info(&self) {
-               for node in self.nodes.write().unwrap().iter_mut() {
+               for node in self.nodes.write().unwrap().unordered_iter_mut() {
                        node.1.announcement_info = None;
                }
        }
@@ -1382,7 +1382,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                let node_id_b = channel_info.node_two.clone();
 
                match channels.entry(short_channel_id) {
-                       BtreeEntry::Occupied(mut entry) => {
+                       IndexedMapEntry::Occupied(mut entry) => {
                                //TODO: because asking the blockchain if short_channel_id is valid is only optional
                                //in the blockchain API, we need to handle it smartly here, though it's unclear
                                //exactly how...
@@ -1401,20 +1401,19 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                        return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreDuplicateGossip});
                                }
                        },
-                       BtreeEntry::Vacant(entry) => {
+                       IndexedMapEntry::Vacant(entry) => {
                                entry.insert(channel_info);
                        }
                };
 
                for current_node_id in [node_id_a, node_id_b].iter() {
                        match nodes.entry(current_node_id.clone()) {
-                               BtreeEntry::Occupied(node_entry) => {
+                               IndexedMapEntry::Occupied(node_entry) => {
                                        node_entry.into_mut().channels.push(short_channel_id);
                                },
-                               BtreeEntry::Vacant(node_entry) => {
+                               IndexedMapEntry::Vacant(node_entry) => {
                                        node_entry.insert(NodeInfo {
                                                channels: vec!(short_channel_id),
-                                               lowest_inbound_channel_fees: None,
                                                announcement_info: None,
                                        });
                                }
@@ -1586,7 +1585,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        for scid in node.channels.iter() {
                                if let Some(chan_info) = channels.remove(scid) {
                                        let other_node_id = if node_id == chan_info.node_one { chan_info.node_two } else { chan_info.node_one };
-                                       if let BtreeEntry::Occupied(mut other_node_entry) = nodes.entry(other_node_id) {
+                                       if let IndexedMapEntry::Occupied(mut other_node_entry) = nodes.entry(other_node_id) {
                                                other_node_entry.get_mut().channels.retain(|chan_id| {
                                                        *scid != *chan_id
                                                });
@@ -1645,7 +1644,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                // Sadly BTreeMap::retain was only stabilized in 1.53 so we can't switch to it for some
                // time.
                let mut scids_to_remove = Vec::new();
-               for (scid, info) in channels.iter_mut() {
+               for (scid, info) in channels.unordered_iter_mut() {
                        if info.one_to_two.is_some() && info.one_to_two.as_ref().unwrap().last_update < min_time_unix {
                                info.one_to_two = None;
                        }
@@ -1715,9 +1714,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
        }
 
        fn update_channel_intern(&self, msg: &msgs::UnsignedChannelUpdate, full_msg: Option<&msgs::ChannelUpdate>, sig: Option<&secp256k1::ecdsa::Signature>) -> Result<(), LightningError> {
-               let dest_node_id;
                let chan_enabled = msg.flags & (1 << 1) != (1 << 1);
-               let chan_was_enabled;
 
                #[cfg(all(feature = "std", not(test), not(feature = "_test_utils")))]
                {
@@ -1765,9 +1762,6 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                                        } else if existing_chan_info.last_update == msg.timestamp {
                                                                return Err(LightningError{err: "Update had same timestamp as last processed update".to_owned(), action: ErrorAction::IgnoreDuplicateGossip});
                                                        }
-                                                       chan_was_enabled = existing_chan_info.enabled;
-                                               } else {
-                                                       chan_was_enabled = false;
                                                }
                                        }
                                }
@@ -1795,7 +1789,6 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
 
                                let msg_hash = hash_to_message!(&Sha256dHash::hash(&msg.encode()[..])[..]);
                                if msg.flags & 1 == 1 {
-                                       dest_node_id = channel.node_one.clone();
                                        check_update_latest!(channel.two_to_one);
                                        if let Some(sig) = sig {
                                                secp_verify_sig!(self.secp_ctx, &msg_hash, &sig, &PublicKey::from_slice(channel.node_two.as_slice()).map_err(|_| LightningError{
@@ -1805,7 +1798,6 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                        }
                                        channel.two_to_one = get_new_channel_info!();
                                } else {
-                                       dest_node_id = channel.node_two.clone();
                                        check_update_latest!(channel.one_to_two);
                                        if let Some(sig) = sig {
                                                secp_verify_sig!(self.secp_ctx, &msg_hash, &sig, &PublicKey::from_slice(channel.node_one.as_slice()).map_err(|_| LightningError{
@@ -1818,51 +1810,13 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        }
                }
 
-               let mut nodes = self.nodes.write().unwrap();
-               if chan_enabled {
-                       let node = nodes.get_mut(&dest_node_id).unwrap();
-                       let mut base_msat = msg.fee_base_msat;
-                       let mut proportional_millionths = msg.fee_proportional_millionths;
-                       if let Some(fees) = node.lowest_inbound_channel_fees {
-                               base_msat = cmp::min(base_msat, fees.base_msat);
-                               proportional_millionths = cmp::min(proportional_millionths, fees.proportional_millionths);
-                       }
-                       node.lowest_inbound_channel_fees = Some(RoutingFees {
-                               base_msat,
-                               proportional_millionths
-                       });
-               } else if chan_was_enabled {
-                       let node = nodes.get_mut(&dest_node_id).unwrap();
-                       let mut lowest_inbound_channel_fees = None;
-
-                       for chan_id in node.channels.iter() {
-                               let chan = channels.get(chan_id).unwrap();
-                               let chan_info_opt;
-                               if chan.node_one == dest_node_id {
-                                       chan_info_opt = chan.two_to_one.as_ref();
-                               } else {
-                                       chan_info_opt = chan.one_to_two.as_ref();
-                               }
-                               if let Some(chan_info) = chan_info_opt {
-                                       if chan_info.enabled {
-                                               let fees = lowest_inbound_channel_fees.get_or_insert(RoutingFees {
-                                                       base_msat: u32::max_value(), proportional_millionths: u32::max_value() });
-                                               fees.base_msat = cmp::min(fees.base_msat, chan_info.fees.base_msat);
-                                               fees.proportional_millionths = cmp::min(fees.proportional_millionths, chan_info.fees.proportional_millionths);
-                                       }
-                               }
-                       }
-
-                       node.lowest_inbound_channel_fees = lowest_inbound_channel_fees;
-               }
-
                Ok(())
        }
 
-       fn remove_channel_in_nodes(nodes: &mut BTreeMap<NodeId, NodeInfo>, chan: &ChannelInfo, short_channel_id: u64) {
+       fn remove_channel_in_nodes(nodes: &mut IndexedMap<NodeId, NodeInfo>, chan: &ChannelInfo, short_channel_id: u64) {
                macro_rules! remove_from_node {
                        ($node_id: expr) => {
-                               if let BtreeEntry::Occupied(mut entry) = nodes.entry($node_id) {
+                               if let IndexedMapEntry::Occupied(mut entry) = nodes.entry($node_id) {
                                        entry.get_mut().channels.retain(|chan_id| {
                                                short_channel_id != *chan_id
                                        });
@@ -1883,8 +1837,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
 impl ReadOnlyNetworkGraph<'_> {
        /// Returns all known valid channels' short ids along with announced channel info.
        ///
-       /// (C-not exported) because we have no mapping for `BTreeMap`s
-       pub fn channels(&self) -> &BTreeMap<u64, ChannelInfo> {
+       /// (C-not exported) because we don't want to return lifetime'd references
+       pub fn channels(&self) -> &IndexedMap<u64, ChannelInfo> {
                &*self.channels
        }
 
@@ -1896,13 +1850,13 @@ impl ReadOnlyNetworkGraph<'_> {
        #[cfg(c_bindings)] // Non-bindings users should use `channels`
        /// Returns the list of channels in the graph
        pub fn list_channels(&self) -> Vec<u64> {
-               self.channels.keys().map(|c| *c).collect()
+               self.channels.unordered_keys().map(|c| *c).collect()
        }
 
        /// Returns all known nodes' public keys along with announced node info.
        ///
-       /// (C-not exported) because we have no mapping for `BTreeMap`s
-       pub fn nodes(&self) -> &BTreeMap<NodeId, NodeInfo> {
+       /// (C-not exported) because we don't want to return lifetime'd references
+       pub fn nodes(&self) -> &IndexedMap<NodeId, NodeInfo> {
                &*self.nodes
        }
 
@@ -1914,7 +1868,7 @@ impl ReadOnlyNetworkGraph<'_> {
        #[cfg(c_bindings)] // Non-bindings users should use `nodes`
        /// Returns the list of nodes in the graph
        pub fn list_nodes(&self) -> Vec<NodeId> {
-               self.nodes.keys().map(|n| *n).collect()
+               self.nodes.unordered_keys().map(|n| *n).collect()
        }
 
        /// Get network addresses by node id.
@@ -3276,7 +3230,6 @@ mod tests {
                // 2. Check we can read a NodeInfo anyways, but set the NodeAnnouncementInfo to None if invalid
                let valid_node_info = NodeInfo {
                        channels: Vec::new(),
-                       lowest_inbound_channel_fees: None,
                        announcement_info: Some(valid_node_ann_info),
                };
 
index c15b612d939b16d4b75e79c81b3077cd60cadb09..8543956ac657de2da7a1ac163478aa1fd5ba22b2 100644 (file)
@@ -582,7 +582,6 @@ impl_writeable_tlv_based!(RouteHintHop, {
 #[derive(Eq, PartialEq)]
 struct RouteGraphNode {
        node_id: NodeId,
-       lowest_fee_to_peer_through_node: u64,
        lowest_fee_to_node: u64,
        total_cltv_delta: u32,
        // The maximum value a yet-to-be-constructed payment path might flow through this node.
@@ -603,9 +602,9 @@ struct RouteGraphNode {
 
 impl cmp::Ord for RouteGraphNode {
        fn cmp(&self, other: &RouteGraphNode) -> cmp::Ordering {
-               let other_score = cmp::max(other.lowest_fee_to_peer_through_node, other.path_htlc_minimum_msat)
+               let other_score = cmp::max(other.lowest_fee_to_node, other.path_htlc_minimum_msat)
                        .saturating_add(other.path_penalty_msat);
-               let self_score = cmp::max(self.lowest_fee_to_peer_through_node, self.path_htlc_minimum_msat)
+               let self_score = cmp::max(self.lowest_fee_to_node, self.path_htlc_minimum_msat)
                        .saturating_add(self.path_penalty_msat);
                other_score.cmp(&self_score).then_with(|| other.node_id.cmp(&self.node_id))
        }
@@ -729,8 +728,6 @@ struct PathBuildingHop<'a> {
        candidate: CandidateRouteHop<'a>,
        fee_msat: u64,
 
-       /// Minimal fees required to route to the source node of the current hop via any of its inbound channels.
-       src_lowest_inbound_fees: RoutingFees,
        /// All the fees paid *after* this channel on the way to the destination
        next_hops_fee_msat: u64,
        /// Fee paid for the use of the current channel (see candidate.fees()).
@@ -888,18 +885,20 @@ impl<'a> PaymentPath<'a> {
        }
 }
 
+#[inline(always)]
+/// Calculate the fees required to route the given amount over a channel with the given fees.
 fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option<u64> {
-       let proportional_fee_millions =
-               amount_msat.checked_mul(channel_fees.proportional_millionths as u64);
-       if let Some(new_fee) = proportional_fee_millions.and_then(|part| {
-                       (channel_fees.base_msat as u64).checked_add(part / 1_000_000) }) {
+       amount_msat.checked_mul(channel_fees.proportional_millionths as u64)
+               .and_then(|part| (channel_fees.base_msat as u64).checked_add(part / 1_000_000))
+}
 
-               Some(new_fee)
-       } else {
-               // This function may be (indirectly) called without any verification,
-               // with channel_fees provided by a caller. We should handle it gracefully.
-               None
-       }
+#[inline(always)]
+/// Calculate the fees required to route the given amount over a channel with the given fees,
+/// saturating to [`u64::max_value`].
+fn compute_fees_saturating(amount_msat: u64, channel_fees: RoutingFees) -> u64 {
+       amount_msat.checked_mul(channel_fees.proportional_millionths as u64)
+               .map(|prop| prop / 1_000_000).unwrap_or(u64::max_value())
+               .saturating_add(channel_fees.base_msat as u64)
 }
 
 /// The default `features` we assume for a node in a route, when no `features` are known about that
@@ -1007,9 +1006,8 @@ where L::Target: Logger {
        // 8. If our maximum channel saturation limit caused us to pick two identical paths, combine
        //    them so that we're not sending two HTLCs along the same path.
 
-       // As for the actual search algorithm,
-       // we do a payee-to-payer pseudo-Dijkstra's sorting by each node's distance from the payee
-       // plus the minimum per-HTLC fee to get from it to another node (aka "shitty pseudo-A*").
+       // As for the actual search algorithm, we do a payee-to-payer Dijkstra's sorting by each node's
+       // distance from the payee
        //
        // We are not a faithful Dijkstra's implementation because we can change values which impact
        // earlier nodes while processing later nodes. Specifically, if we reach a channel with a lower
@@ -1044,10 +1042,6 @@ where L::Target: Logger {
        // runtime for little gain. Specifically, the current algorithm rather efficiently explores the
        // graph for candidate paths, calculating the maximum value which can realistically be sent at
        // the same time, remaining generic across different payment values.
-       //
-       // TODO: There are a few tweaks we could do, including possibly pre-calculating more stuff
-       // to use as the A* heuristic beyond just the cost to get one node further than the current
-       // one.
 
        let network_channels = network_graph.channels();
        let network_nodes = network_graph.nodes();
@@ -1097,7 +1091,7 @@ where L::Target: Logger {
                }
        }
 
-       // The main heap containing all candidate next-hops sorted by their score (max(A* fee,
+       // The main heap containing all candidate next-hops sorted by their score (max(fee,
        // htlc_minimum)). Ideally this would be a heap which allowed cheap score reduction instead of
        // adding duplicate entries when we find a better path to a given node.
        let mut targets: BinaryHeap<RouteGraphNode> = BinaryHeap::new();
@@ -1262,10 +1256,10 @@ where L::Target: Logger {
                                                // might violate htlc_minimum_msat on the hops which are next along the
                                                // payment path (upstream to the payee). To avoid that, we recompute
                                                // path fees knowing the final path contribution after constructing it.
-                                               let path_htlc_minimum_msat = compute_fees($next_hops_path_htlc_minimum_msat, $candidate.fees())
-                                                       .and_then(|fee_msat| fee_msat.checked_add($next_hops_path_htlc_minimum_msat))
-                                                       .map(|fee_msat| cmp::max(fee_msat, $candidate.htlc_minimum_msat()))
-                                                       .unwrap_or_else(|| u64::max_value());
+                                               let path_htlc_minimum_msat = cmp::max(
+                                                       compute_fees_saturating($next_hops_path_htlc_minimum_msat, $candidate.fees())
+                                                               .saturating_add($next_hops_path_htlc_minimum_msat),
+                                                       $candidate.htlc_minimum_msat());
                                                let hm_entry = dist.entry($src_node_id);
                                                let old_entry = hm_entry.or_insert_with(|| {
                                                        // If there was previously no known way to access the source node
@@ -1273,20 +1267,10 @@ where L::Target: Logger {
                                                        // semi-dummy record just to compute the fees to reach the source node.
                                                        // This will affect our decision on selecting short_channel_id
                                                        // as a way to reach the $dest_node_id.
-                                                       let mut fee_base_msat = 0;
-                                                       let mut fee_proportional_millionths = 0;
-                                                       if let Some(Some(fees)) = network_nodes.get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) {
-                                                               fee_base_msat = fees.base_msat;
-                                                               fee_proportional_millionths = fees.proportional_millionths;
-                                                       }
                                                        PathBuildingHop {
                                                                node_id: $dest_node_id.clone(),
                                                                candidate: $candidate.clone(),
                                                                fee_msat: 0,
-                                                               src_lowest_inbound_fees: RoutingFees {
-                                                                       base_msat: fee_base_msat,
-                                                                       proportional_millionths: fee_proportional_millionths,
-                                                               },
                                                                next_hops_fee_msat: u64::max_value(),
                                                                hop_use_fee_msat: u64::max_value(),
                                                                total_fee_msat: u64::max_value(),
@@ -1309,38 +1293,15 @@ where L::Target: Logger {
 
                                                if should_process {
                                                        let mut hop_use_fee_msat = 0;
-                                                       let mut total_fee_msat = $next_hops_fee_msat;
+                                                       let mut total_fee_msat: u64 = $next_hops_fee_msat;
 
                                                        // Ignore hop_use_fee_msat for channel-from-us as we assume all channels-from-us
                                                        // will have the same effective-fee
                                                        if $src_node_id != our_node_id {
-                                                               match compute_fees(amount_to_transfer_over_msat, $candidate.fees()) {
-                                                                       // max_value means we'll always fail
-                                                                       // the old_entry.total_fee_msat > total_fee_msat check
-                                                                       None => total_fee_msat = u64::max_value(),
-                                                                       Some(fee_msat) => {
-                                                                               hop_use_fee_msat = fee_msat;
-                                                                               total_fee_msat += hop_use_fee_msat;
-                                                                               // When calculating the lowest inbound fees to a node, we
-                                                                               // calculate fees here not based on the actual value we think
-                                                                               // will flow over this channel, but on the minimum value that
-                                                                               // we'll accept flowing over it. The minimum accepted value
-                                                                               // is a constant through each path collection run, ensuring
-                                                                               // consistent basis. Otherwise we may later find a
-                                                                               // different path to the source node that is more expensive,
-                                                                               // but which we consider to be cheaper because we are capacity
-                                                                               // constrained and the relative fee becomes lower.
-                                                                               match compute_fees(minimal_value_contribution_msat, old_entry.src_lowest_inbound_fees)
-                                                                                               .map(|a| a.checked_add(total_fee_msat)) {
-                                                                                       Some(Some(v)) => {
-                                                                                               total_fee_msat = v;
-                                                                                       },
-                                                                                       _ => {
-                                                                                               total_fee_msat = u64::max_value();
-                                                                                       }
-                                                                               };
-                                                                       }
-                                                               }
+                                                               // Note that `u64::max_value` means we'll always fail the
+                                                               // `old_entry.total_fee_msat > total_fee_msat` check below
+                                                               hop_use_fee_msat = compute_fees_saturating(amount_to_transfer_over_msat, $candidate.fees());
+                                                               total_fee_msat = total_fee_msat.saturating_add(hop_use_fee_msat);
                                                        }
 
                                                        let channel_usage = ChannelUsage {
@@ -1355,8 +1316,7 @@ where L::Target: Logger {
                                                                .saturating_add(channel_penalty_msat);
                                                        let new_graph_node = RouteGraphNode {
                                                                node_id: $src_node_id,
-                                                               lowest_fee_to_peer_through_node: total_fee_msat,
-                                                               lowest_fee_to_node: $next_hops_fee_msat as u64 + hop_use_fee_msat,
+                                                               lowest_fee_to_node: total_fee_msat,
                                                                total_cltv_delta: hop_total_cltv_delta,
                                                                value_contribution_msat,
                                                                path_htlc_minimum_msat,
@@ -5544,9 +5504,9 @@ mod tests {
                'load_endpoints: for _ in 0..10 {
                        loop {
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let src = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                let payment_params = PaymentParameters::from_node_id(dst);
                                let amt = seed as u64 % 200_000_000;
                                let params = ProbabilisticScoringParameters::default();
@@ -5582,9 +5542,9 @@ mod tests {
                'load_endpoints: for _ in 0..10 {
                        loop {
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let src = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                let payment_params = PaymentParameters::from_node_id(dst).with_features(channelmanager::provided_invoice_features(&config));
                                let amt = seed as u64 % 200_000_000;
                                let params = ProbabilisticScoringParameters::default();
@@ -5639,8 +5599,8 @@ pub(crate) mod bench_utils {
        use std::fs::File;
        /// Tries to open a network graph file, or panics with a URL to fetch it.
        pub(crate) fn get_route_file() -> Result<std::fs::File, &'static str> {
-               let res = File::open("net_graph-2021-05-31.bin") // By default we're run in RL/lightning
-                       .or_else(|_| File::open("lightning/net_graph-2021-05-31.bin")) // We may be run manually in RL/
+               let res = File::open("net_graph-2023-01-18.bin") // By default we're run in RL/lightning
+                       .or_else(|_| File::open("lightning/net_graph-2023-01-18.bin")) // We may be run manually in RL/
                        .or_else(|_| { // Fall back to guessing based on the binary location
                                // path is likely something like .../rust-lightning/target/debug/deps/lightning-...
                                let mut path = std::env::current_exe().unwrap();
@@ -5649,11 +5609,11 @@ pub(crate) mod bench_utils {
                                path.pop(); // debug
                                path.pop(); // target
                                path.push("lightning");
-                               path.push("net_graph-2021-05-31.bin");
+                               path.push("net_graph-2023-01-18.bin");
                                eprintln!("{}", path.to_str().unwrap());
                                File::open(path)
                        })
-               .map_err(|_| "Please fetch https://bitcoin.ninja/ldk-net_graph-v0.0.15-2021-05-31.bin and place it at lightning/net_graph-2021-05-31.bin");
+               .map_err(|_| "Please fetch https://bitcoin.ninja/ldk-net_graph-v0.0.113-2023-01-18.bin and place it at lightning/net_graph-2023-01-18.bin");
                #[cfg(require_route_graph_test)]
                return Ok(res.unwrap());
                #[cfg(not(require_route_graph_test))]
@@ -5782,9 +5742,9 @@ mod benches {
                'load_endpoints: for _ in 0..150 {
                        loop {
                                seed *= 0xdeadbeef;
-                               let src = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let src = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed *= 0xdeadbeef;
-                               let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                let params = PaymentParameters::from_node_id(dst).with_features(features.clone());
                                let first_hop = first_hop(src);
                                let amt = seed as u64 % 1_000_000;
diff --git a/lightning/src/util/indexed_map.rs b/lightning/src/util/indexed_map.rs
new file mode 100644 (file)
index 0000000..cccbfe7
--- /dev/null
@@ -0,0 +1,203 @@
+//! This module has a map which can be iterated in a deterministic order. See the [`IndexedMap`].
+
+use crate::prelude::{HashMap, hash_map};
+use alloc::collections::{BTreeSet, btree_set};
+use core::hash::Hash;
+use core::cmp::Ord;
+use core::ops::RangeBounds;
+
+/// A map which can be iterated in a deterministic order.
+///
+/// This would traditionally be accomplished by simply using a [`BTreeMap`], however B-Trees
+/// generally have very slow lookups. Because we use a nodes+channels map while finding routes
+/// across the network graph, our network graph backing map must be as performant as possible.
+/// However, because peers expect to sync the network graph from us (and we need to support that
+/// without holding a lock on the graph for the duration of the sync or dumping the entire graph
+/// into our outbound message queue), we need an iterable map with a consistent iteration order we
+/// can jump to a starting point on.
+///
+/// Thus, we have a custom data structure here - its API mimics that of Rust's [`BTreeMap`], but is
+/// actually backed by a [`HashMap`], with some additional tracking to ensure we can iterate over
+/// keys in the order defined by [`Ord`].
+///
+/// [`BTreeMap`]: alloc::collections::BTreeMap
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct IndexedMap<K: Hash + Ord, V> {
+       map: HashMap<K, V>,
+       // TODO: Explore swapping this for a sorted vec (that is only sorted on first range() call)
+       keys: BTreeSet<K>,
+}
+
+impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
+       /// Constructs a new, empty map
+       pub fn new() -> Self {
+               Self {
+                       map: HashMap::new(),
+                       keys: BTreeSet::new(),
+               }
+       }
+
+       #[inline(always)]
+       /// Fetches the element with the given `key`, if one exists.
+       pub fn get(&self, key: &K) -> Option<&V> {
+               self.map.get(key)
+       }
+
+       /// Fetches a mutable reference to the element with the given `key`, if one exists.
+       pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
+               self.map.get_mut(key)
+       }
+
+       #[inline]
+       /// Returns true if an element with the given `key` exists in the map.
+       pub fn contains_key(&self, key: &K) -> bool {
+               self.map.contains_key(key)
+       }
+
+       /// Removes the element with the given `key`, returning it, if one exists.
+       pub fn remove(&mut self, key: &K) -> Option<V> {
+               let ret = self.map.remove(key);
+               if let Some(_) = ret {
+                       assert!(self.keys.remove(key), "map and keys must be consistent");
+               }
+               ret
+       }
+
+       /// Inserts the given `key`/`value` pair into the map, returning the element that was
+       /// previously stored at the given `key`, if one exists.
+       pub fn insert(&mut self, key: K, value: V) -> Option<V> {
+               let ret = self.map.insert(key.clone(), value);
+               if ret.is_none() {
+                       assert!(self.keys.insert(key), "map and keys must be consistent");
+               }
+               ret
+       }
+
+       /// Returns an [`Entry`] for the given `key` in the map, allowing access to the value.
+       pub fn entry(&mut self, key: K) -> Entry<'_, K, V> {
+               match self.map.entry(key.clone()) {
+                       hash_map::Entry::Vacant(entry) => {
+                               Entry::Vacant(VacantEntry {
+                                       underlying_entry: entry,
+                                       key,
+                                       keys: &mut self.keys,
+                               })
+                       },
+                       hash_map::Entry::Occupied(entry) => {
+                               Entry::Occupied(OccupiedEntry {
+                                       underlying_entry: entry,
+                                       keys: &mut self.keys,
+                               })
+                       }
+               }
+       }
+
+       /// Returns an iterator which iterates over the keys in the map, in a random order.
+       pub fn unordered_keys(&self) -> impl Iterator<Item = &K> {
+               self.map.keys()
+       }
+
+       /// Returns an iterator which iterates over the `key`/`value` pairs in a random order.
+       pub fn unordered_iter(&self) -> impl Iterator<Item = (&K, &V)> {
+               self.map.iter()
+       }
+
+       /// Returns an iterator which iterates over the `key`s and mutable references to `value`s in a
+       /// random order.
+       pub fn unordered_iter_mut(&mut self) -> impl Iterator<Item = (&K, &mut V)> {
+               self.map.iter_mut()
+       }
+
+       /// Returns an iterator which iterates over the `key`/`value` pairs in a given range.
+       pub fn range<R: RangeBounds<K>>(&self, range: R) -> Range<K, V> {
+               Range {
+                       inner_range: self.keys.range(range),
+                       map: &self.map,
+               }
+       }
+
+       /// Returns the number of `key`/`value` pairs in the map
+       pub fn len(&self) -> usize {
+               self.map.len()
+       }
+
+       /// Returns true if there are no elements in the map
+       pub fn is_empty(&self) -> bool {
+               self.map.is_empty()
+       }
+}
+
+/// An iterator over a range of values in an [`IndexedMap`]
+pub struct Range<'a, K: Hash + Ord, V> {
+       inner_range: btree_set::Range<'a, K>,
+       map: &'a HashMap<K, V>,
+}
+impl<'a, K: Hash + Ord, V: 'a> Iterator for Range<'a, K, V> {
+       type Item = (&'a K, &'a V);
+       fn next(&mut self) -> Option<(&'a K, &'a V)> {
+               self.inner_range.next().map(|k| {
+                       (k, self.map.get(k).expect("map and keys must be consistent"))
+               })
+       }
+}
+
+/// An [`Entry`] for a key which currently has no value
+pub struct VacantEntry<'a, K: Hash + Ord, V> {
+       #[cfg(feature = "hashbrown")]
+       underlying_entry: hash_map::VacantEntry<'a, K, V, hash_map::DefaultHashBuilder>,
+       #[cfg(not(feature = "hashbrown"))]
+       underlying_entry: hash_map::VacantEntry<'a, K, V>,
+       key: K,
+       keys: &'a mut BTreeSet<K>,
+}
+
+/// An [`Entry`] for an existing key-value pair
+pub struct OccupiedEntry<'a, K: Hash + Ord, V> {
+       #[cfg(feature = "hashbrown")]
+       underlying_entry: hash_map::OccupiedEntry<'a, K, V, hash_map::DefaultHashBuilder>,
+       #[cfg(not(feature = "hashbrown"))]
+       underlying_entry: hash_map::OccupiedEntry<'a, K, V>,
+       keys: &'a mut BTreeSet<K>,
+}
+
+/// A mutable reference to a position in the map. This can be used to reference, add, or update the
+/// value at a fixed key.
+pub enum Entry<'a, K: Hash + Ord, V> {
+       /// A mutable reference to a position within the map where there is no value.
+       Vacant(VacantEntry<'a, K, V>),
+       /// A mutable reference to a position within the map where there is currently a value.
+       Occupied(OccupiedEntry<'a, K, V>),
+}
+
+impl<'a, K: Hash + Ord, V> VacantEntry<'a, K, V> {
+       /// Insert a value into the position described by this entry.
+       pub fn insert(self, value: V) -> &'a mut V {
+               assert!(self.keys.insert(self.key), "map and keys must be consistent");
+               self.underlying_entry.insert(value)
+       }
+}
+
+impl<'a, K: Hash + Ord, V> OccupiedEntry<'a, K, V> {
+       /// Remove the value at the position described by this entry.
+       pub fn remove_entry(self) -> (K, V) {
+               let res = self.underlying_entry.remove_entry();
+               assert!(self.keys.remove(&res.0), "map and keys must be consistent");
+               res
+       }
+
+       /// Get a reference to the value at the position described by this entry.
+       pub fn get(&self) -> &V {
+               self.underlying_entry.get()
+       }
+
+       /// Get a mutable reference to the value at the position described by this entry.
+       pub fn get_mut(&mut self) -> &mut V {
+               self.underlying_entry.get_mut()
+       }
+
+       /// Consume this entry, returning a mutable reference to the value at the position described by
+       /// this entry.
+       pub fn into_mut(self) -> &'a mut V {
+               self.underlying_entry.into_mut()
+       }
+}
index 1d46865b6019b0158659ccab2c590c419416e36c..1673bd07f69b24d7af49f7d9c3604ea33f000422 100644 (file)
@@ -40,6 +40,8 @@ pub(crate) mod transaction_utils;
 pub(crate) mod scid_utils;
 pub(crate) mod time;
 
+pub mod indexed_map;
+
 /// Logging macro utilities.
 #[macro_use]
 pub(crate) mod macro_logger;