Refactor lightning-invoice/src/utils.rs to yield iterators
[rust-lightning] / lightning-invoice / src / utils.rs
index 8141f39559181e6d3b1bff1432798c5ec4ccb7c6..d8e7bf12726454cfec9ac0134e7f3ab9fd6a1406 100644 (file)
@@ -7,7 +7,7 @@ use bech32::ToBase32;
 use bitcoin_hashes::Hash;
 use lightning::chain;
 use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator};
-use lightning::chain::keysinterface::{Recipient, NodeSigner, SignerProvider, EntropySource};
+use lightning::sign::{Recipient, NodeSigner, SignerProvider, EntropySource};
 use lightning::ln::{PaymentHash, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, MIN_FINAL_CLTV_EXPIRY_DELTA};
 use lightning::ln::channelmanager::{PhantomRouteHints, MIN_CLTV_EXPIRY_DELTA};
@@ -18,6 +18,7 @@ use lightning::util::logger::Logger;
 use secp256k1::PublicKey;
 use core::ops::Deref;
 use core::time::Duration;
+use core::iter::Iterator;
 
 /// Utility to create an invoice that can be paid to one of multiple nodes, or a "phantom invoice."
 /// See [`PhantomKeysManager`] for more information on phantom node payments.
@@ -50,7 +51,7 @@ use core::time::Duration;
 /// invoices in its `sign_invoice` implementation ([`PhantomKeysManager`] satisfies this
 /// requirement).
 ///
-/// [`PhantomKeysManager`]: lightning::chain::keysinterface::PhantomKeysManager
+/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
 /// [`ChannelManager::get_phantom_route_hints`]: lightning::ln::channelmanager::ChannelManager::get_phantom_route_hints
 /// [`ChannelManager::create_inbound_payment`]: lightning::ln::channelmanager::ChannelManager::create_inbound_payment
 /// [`ChannelManager::create_inbound_payment_for_hash`]: lightning::ln::channelmanager::ChannelManager::create_inbound_payment_for_hash
@@ -107,7 +108,7 @@ where
 /// invoices in its `sign_invoice` implementation ([`PhantomKeysManager`] satisfies this
 /// requirement).
 ///
-/// [`PhantomKeysManager`]: lightning::chain::keysinterface::PhantomKeysManager
+/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
 /// [`ChannelManager::get_phantom_route_hints`]: lightning::ln::channelmanager::ChannelManager::get_phantom_route_hints
 /// [`ChannelManager::create_inbound_payment`]: lightning::ln::channelmanager::ChannelManager::create_inbound_payment
 /// [`ChannelManager::create_inbound_payment_for_hash`]: lightning::ln::channelmanager::ChannelManager::create_inbound_payment_for_hash
@@ -132,6 +133,8 @@ where
        )
 }
 
+const MAX_CHANNEL_HINTS: usize = 3;
+
 fn _create_phantom_invoice<ES: Deref, NS: Deref, L: Deref>(
        amt_msat: Option<u64>, payment_hash: Option<PaymentHash>, description: InvoiceDescription,
        invoice_expiry_delta_secs: u32, phantom_route_hints: Vec<PhantomRouteHints>, entropy_source: ES,
@@ -202,7 +205,8 @@ where
                invoice = invoice.amount_milli_satoshis(amt);
        }
 
-       for route_hint in select_phantom_hints(amt_msat, phantom_route_hints, logger) {
+
+       for route_hint in select_phantom_hints(amt_msat, phantom_route_hints, logger).take(MAX_CHANNEL_HINTS) {
                invoice = invoice.private_route(route_hint);
        }
 
@@ -227,38 +231,50 @@ where
 /// * Select up to three channels per node.
 /// * Select one hint from each node, up to three hints or until we run out of hints.
 ///
-/// [`PhantomKeysManager`]: lightning::chain::keysinterface::PhantomKeysManager
+/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
 fn select_phantom_hints<L: Deref>(amt_msat: Option<u64>, phantom_route_hints: Vec<PhantomRouteHints>,
-       logger: L) -> Vec<RouteHint>
+       logger: L) -> impl Iterator<Item = RouteHint>
 where
        L::Target: Logger,
 {
-       let mut phantom_hints: Vec<Vec<RouteHint>> = Vec::new();
+       let mut phantom_hints: Vec<_> = Vec::new();
 
        for PhantomRouteHints { channels, phantom_scid, real_node_pubkey } in phantom_route_hints {
                log_trace!(logger, "Generating phantom route hints for node {}",
                        log_pubkey!(real_node_pubkey));
-               let mut route_hints = sort_and_filter_channels(channels, amt_msat, &logger);
+               let route_hints = sort_and_filter_channels(channels, amt_msat, &logger);
 
                // If we have any public channel, the route hints from `sort_and_filter_channels` will be
                // empty. In that case we create a RouteHint on which we will push a single hop with the
                // phantom route into the invoice, and let the sender find the path to the `real_node_pubkey`
                // node by looking at our public channels.
-               if route_hints.is_empty() {
-                       route_hints.push(RouteHint(vec![]))
-               }
-               for route_hint in &mut route_hints {
-                       route_hint.0.push(RouteHintHop {
-                               src_node_id: real_node_pubkey,
-                               short_channel_id: phantom_scid,
-                               fees: RoutingFees {
-                                       base_msat: 0,
-                                       proportional_millionths: 0,
-                               },
-                               cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA,
-                               htlc_minimum_msat: None,
-                               htlc_maximum_msat: None,});
-               }
+               let empty_route_hints = route_hints.len() == 0;
+               let mut have_pushed_empty = false;
+               let route_hints = route_hints
+                       .chain(core::iter::from_fn(move || {
+                               if empty_route_hints && !have_pushed_empty {
+                                       // set flag of having handled the empty route_hints and ensure empty vector
+                                       // returned only once
+                                       have_pushed_empty = true;
+                                       Some(RouteHint(Vec::new()))
+                               } else {
+                                       None
+                               }
+                       }))
+                       .map(move |mut hint| {
+                               hint.0.push(RouteHintHop {
+                                       src_node_id: real_node_pubkey,
+                                       short_channel_id: phantom_scid,
+                                       fees: RoutingFees {
+                                               base_msat: 0,
+                                               proportional_millionths: 0,
+                                       },
+                                       cltv_expiry_delta: MIN_CLTV_EXPIRY_DELTA,
+                                       htlc_minimum_msat: None,
+                                       htlc_maximum_msat: None,
+                               });
+                               hint
+                       });
 
                phantom_hints.push(route_hints);
        }
@@ -267,29 +283,34 @@ where
        // the hints across our real nodes we add one hint from each in turn until no node has any hints
        // left (if one node has more hints than any other, these will accumulate at the end of the
        // vector).
-       let mut invoice_hints: Vec<RouteHint> = Vec::new();
-       let mut hint_idx = 0;
+       rotate_through_iterators(phantom_hints)
+}
 
-       loop {
-               let mut remaining_hints = false;
+/// Draw items iteratively from multiple iterators.  The items are retrieved by index and
+/// rotates through the iterators - first the zero index then the first index then second index, etc.
+fn rotate_through_iterators<T, I: Iterator<Item = T>>(mut vecs: Vec<I>) -> impl Iterator<Item = T> {
+       let mut iterations = 0;
 
-               for hints in phantom_hints.iter() {
-                       if invoice_hints.len() == 3 {
-                               return invoice_hints
+       core::iter::from_fn(move || {
+               let mut exhausted_iterators = 0;
+               loop {
+                       if vecs.is_empty() {
+                               return None;
                        }
-
-                       if hint_idx < hints.len() {
-                               invoice_hints.push(hints[hint_idx].clone());
-                               remaining_hints = true
+                       let next_idx = iterations % vecs.len();
+                       iterations += 1;
+                       if let Some(item) = vecs[next_idx].next() {
+                               return Some(item);
+                       }
+                       // exhausted_vectors increase when the "next_idx" vector is exhausted
+                       exhausted_iterators += 1;
+                       // The check for exhausted iterators gets reset to 0 after each yield of `Some()`
+                       // The loop will return None when all of the nested iterators are exhausted
+                       if exhausted_iterators == vecs.len() {
+                               return None;
                        }
                }
-
-               if !remaining_hints {
-                       return invoice_hints
-               }
-
-               hint_idx +=1;
-       }
+       })
 }
 
 #[cfg(feature = "std")]
@@ -575,8 +596,13 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has
 /// * Sorted by lowest inbound capacity if an online channel with the minimum amount requested exists,
 ///   otherwise sort by highest inbound capacity to give the payment the best chance of succeeding.
 fn sort_and_filter_channels<L: Deref>(
-       channels: Vec<ChannelDetails>, min_inbound_capacity_msat: Option<u64>, logger: &L
-) -> Vec<RouteHint> where L::Target: Logger {
+       channels: Vec<ChannelDetails>,
+       min_inbound_capacity_msat: Option<u64>,
+       logger: &L,
+) -> impl ExactSizeIterator<Item = RouteHint>
+where
+       L::Target: Logger,
+{
        let mut filtered_channels: HashMap<PublicKey, ChannelDetails> = HashMap::new();
        let min_inbound_capacity = min_inbound_capacity_msat.unwrap_or(0);
        let mut min_capacity_channel_exists = false;
@@ -584,6 +610,20 @@ fn sort_and_filter_channels<L: Deref>(
        let mut online_min_capacity_channel_exists = false;
        let mut has_pub_unconf_chan = false;
 
+       let route_hint_from_channel = |channel: ChannelDetails| {
+               let forwarding_info = channel.counterparty.forwarding_info.as_ref().unwrap();
+               RouteHint(vec![RouteHintHop {
+                       src_node_id: channel.counterparty.node_id,
+                       short_channel_id: channel.get_inbound_payment_scid().unwrap(),
+                       fees: RoutingFees {
+                               base_msat: forwarding_info.fee_base_msat,
+                               proportional_millionths: forwarding_info.fee_proportional_millionths,
+                       },
+                       cltv_expiry_delta: forwarding_info.cltv_expiry_delta,
+                       htlc_minimum_msat: channel.inbound_htlc_minimum_msat,
+                       htlc_maximum_msat: channel.inbound_htlc_maximum_msat,}])
+       };
+
        log_trace!(logger, "Considering {} channels for invoice route hints", channels.len());
        for channel in channels.into_iter().filter(|chan| chan.is_channel_ready) {
                if channel.get_inbound_payment_scid().is_none() || channel.counterparty.forwarding_info.is_none() {
@@ -602,7 +642,7 @@ fn sort_and_filter_channels<L: Deref>(
                                // look at the public channels instead.
                                log_trace!(logger, "Not including channels in invoice route hints on account of public channel {}",
                                        log_bytes!(channel.channel_id));
-                               return vec![]
+                               return vec![].into_iter().take(MAX_CHANNEL_HINTS).map(route_hint_from_channel);
                        }
                }
 
@@ -629,7 +669,7 @@ fn sort_and_filter_channels<L: Deref>(
                                // previous channel to avoid announcing non-public channels.
                                let new_now_public = channel.is_public && !entry.get().is_public;
                                // Decide whether we prefer the currently selected channel with the node to the new one,
-                               // based on their inbound capacity. 
+                               // based on their inbound capacity.
                                let prefer_current = prefer_current_channel(min_inbound_capacity_msat, current_max_capacity,
                                        channel.inbound_capacity_msat);
                                // If the public-ness of the channel has not changed (in which case simply defer to
@@ -662,19 +702,6 @@ fn sort_and_filter_channels<L: Deref>(
                }
        }
 
-       let route_hint_from_channel = |channel: ChannelDetails| {
-               let forwarding_info = channel.counterparty.forwarding_info.as_ref().unwrap();
-               RouteHint(vec![RouteHintHop {
-                       src_node_id: channel.counterparty.node_id,
-                       short_channel_id: channel.get_inbound_payment_scid().unwrap(),
-                       fees: RoutingFees {
-                               base_msat: forwarding_info.fee_base_msat,
-                               proportional_millionths: forwarding_info.fee_proportional_millionths,
-                       },
-                       cltv_expiry_delta: forwarding_info.cltv_expiry_delta,
-                       htlc_minimum_msat: channel.inbound_htlc_minimum_msat,
-                       htlc_maximum_msat: channel.inbound_htlc_maximum_msat,}])
-       };
        // If all channels are private, prefer to return route hints which have a higher capacity than
        // the payment value and where we're currently connected to the channel counterparty.
        // Even if we cannot satisfy both goals, always ensure we include *some* hints, preferring
@@ -724,7 +751,8 @@ fn sort_and_filter_channels<L: Deref>(
                        } else {
                                b.inbound_capacity_msat.cmp(&a.inbound_capacity_msat)
                        }});
-               eligible_channels.into_iter().take(3).map(route_hint_from_channel).collect::<Vec<RouteHint>>()
+
+               eligible_channels.into_iter().take(MAX_CHANNEL_HINTS).map(route_hint_from_channel)
 }
 
 /// prefer_current_channel chooses a channel to use for route hints between a currently selected and candidate
@@ -768,7 +796,7 @@ mod test {
        use crate::{Currency, Description, InvoiceDescription, SignOrCreationError, CreationError};
        use bitcoin_hashes::{Hash, sha256};
        use bitcoin_hashes::sha256::Hash as Sha256;
-       use lightning::chain::keysinterface::PhantomKeysManager;
+       use lightning::sign::PhantomKeysManager;
        use lightning::events::{MessageSendEvent, MessageSendEventsProvider, Event};
        use lightning::ln::{PaymentPreimage, PaymentHash};
        use lightning::ln::channelmanager::{PhantomRouteHints, MIN_FINAL_CLTV_EXPIRY_DELTA, PaymentId, RecipientOnionFields, Retry};
@@ -777,7 +805,7 @@ mod test {
        use lightning::routing::router::{PaymentParameters, RouteParameters};
        use lightning::util::test_utils;
        use lightning::util::config::UserConfig;
-       use crate::utils::create_invoice_from_channelmanager_and_duration_since_epoch;
+       use crate::utils::{create_invoice_from_channelmanager_and_duration_since_epoch, rotate_through_iterators};
        use std::collections::HashSet;
 
        #[test]
@@ -793,10 +821,10 @@ mod test {
 
                // Minimum set, prefer candidate channel over minimum + buffer.
                assert_eq!(crate::utils::prefer_current_channel(Some(100), 105, 125), false);
-               
+
                // Minimum set, both channels sufficient, prefer smaller current channel.
                assert_eq!(crate::utils::prefer_current_channel(Some(100), 115, 125), true);
-               
+
                // Minimum set, both channels sufficient, prefer smaller candidate channel.
                assert_eq!(crate::utils::prefer_current_channel(Some(100), 200, 160), false);
 
@@ -1886,4 +1914,111 @@ mod test {
                        _ => panic!(),
                }
        }
+
+       #[test]
+       fn test_rotate_through_iterators() {
+               // two nested vectors
+               let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec!["a1", "b1"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "a1", "b0", "b1", "c0"];
+               assert_eq!(expected, result);
+
+               // test single nested vector
+               let a = vec![vec!["a0", "b0", "c0"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "b0", "c0"];
+               assert_eq!(expected, result);
+
+               // test second vector with only one element
+               let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec!["a1"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "a1", "b0", "c0"];
+               assert_eq!(expected, result);
+
+               // test three nestend vectors
+               let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec!["a2"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "a1", "a2", "b1", "c1"];
+               assert_eq!(expected, result);
+
+               // test single nested vector with a single value
+               let a = vec![vec!["a0"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0"];
+               assert_eq!(expected, result);
+
+               // test single empty nested vector
+               let a:Vec<std::vec::IntoIter<&str>> = vec![vec![].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<&str>>();
+               let expected:Vec<&str> = vec![];
+
+               assert_eq!(expected, result);
+
+               // test first nested vector is empty
+               let a:Vec<std::vec::IntoIter<&str>>= vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<&str>>();
+
+               let expected = vec!["a1", "b1", "c1"];
+               assert_eq!(expected, result);
+
+               // test two empty vectors
+               let a:Vec<std::vec::IntoIter<&str>> = vec![vec![].into_iter(), vec![].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<&str>>();
+
+               let expected:Vec<&str> = vec![];
+               assert_eq!(expected, result);
+
+               // test an empty vector amongst other filled vectors
+               let a = vec![
+                       vec!["a0", "b0", "c0"].into_iter(),
+                       vec![].into_iter(),
+                       vec!["a1", "b1", "c1"].into_iter(),
+                       vec!["a2", "b2", "c2"].into_iter(),
+               ];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "a1", "a2", "b0", "b1", "b2", "c0", "c1", "c2"];
+               assert_eq!(expected, result);
+
+               // test a filled vector between two empty vectors
+               let a = vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec![].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a1", "b1", "c1"];
+               assert_eq!(expected, result);
+
+               // test an empty vector at the end of the vectors
+               let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec![].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "b0", "c0"];
+               assert_eq!(expected, result);
+
+               // test multiple empty vectors amongst multiple filled vectors
+               let a = vec![
+                       vec![].into_iter(),
+                       vec!["a1", "b1", "c1"].into_iter(),
+                       vec![].into_iter(),
+                       vec!["a3", "b3"].into_iter(),
+                       vec![].into_iter(),
+               ];
+
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a1", "a3", "b1", "b3", "c1"];
+               assert_eq!(expected, result);
+
+               // test one element in the first nested vectore and two elements in the second nested
+               // vector
+               let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "a1", "b1"];
+               assert_eq!(expected, result);
+       }
 }