Adding rotate_through_iterators for select_phantom_hints refactor
[rust-lightning] / lightning-invoice / src / utils.rs
index fef7a78af145af8d2c8a7a6918c865e8dc06fd76..d9e1847f1a735442a751e35cda6b34cc6d8c1fca 100644 (file)
@@ -7,7 +7,7 @@ use bech32::ToBase32;
 use bitcoin_hashes::Hash;
 use lightning::chain;
 use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator};
-use lightning::chain::keysinterface::{Recipient, NodeSigner, SignerProvider, EntropySource};
+use lightning::sign::{Recipient, NodeSigner, SignerProvider, EntropySource};
 use lightning::ln::{PaymentHash, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, MIN_FINAL_CLTV_EXPIRY_DELTA};
 use lightning::ln::channelmanager::{PhantomRouteHints, MIN_CLTV_EXPIRY_DELTA};
@@ -18,6 +18,7 @@ use lightning::util::logger::Logger;
 use secp256k1::PublicKey;
 use core::ops::Deref;
 use core::time::Duration;
+use core::iter::Iterator;
 
 /// Utility to create an invoice that can be paid to one of multiple nodes, or a "phantom invoice."
 /// See [`PhantomKeysManager`] for more information on phantom node payments.
@@ -50,7 +51,7 @@ use core::time::Duration;
 /// invoices in its `sign_invoice` implementation ([`PhantomKeysManager`] satisfies this
 /// requirement).
 ///
-/// [`PhantomKeysManager`]: lightning::chain::keysinterface::PhantomKeysManager
+/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
 /// [`ChannelManager::get_phantom_route_hints`]: lightning::ln::channelmanager::ChannelManager::get_phantom_route_hints
 /// [`ChannelManager::create_inbound_payment`]: lightning::ln::channelmanager::ChannelManager::create_inbound_payment
 /// [`ChannelManager::create_inbound_payment_for_hash`]: lightning::ln::channelmanager::ChannelManager::create_inbound_payment_for_hash
@@ -107,7 +108,7 @@ where
 /// invoices in its `sign_invoice` implementation ([`PhantomKeysManager`] satisfies this
 /// requirement).
 ///
-/// [`PhantomKeysManager`]: lightning::chain::keysinterface::PhantomKeysManager
+/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
 /// [`ChannelManager::get_phantom_route_hints`]: lightning::ln::channelmanager::ChannelManager::get_phantom_route_hints
 /// [`ChannelManager::create_inbound_payment`]: lightning::ln::channelmanager::ChannelManager::create_inbound_payment
 /// [`ChannelManager::create_inbound_payment_for_hash`]: lightning::ln::channelmanager::ChannelManager::create_inbound_payment_for_hash
@@ -227,7 +228,7 @@ where
 /// * Select up to three channels per node.
 /// * Select one hint from each node, up to three hints or until we run out of hints.
 ///
-/// [`PhantomKeysManager`]: lightning::chain::keysinterface::PhantomKeysManager
+/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
 fn select_phantom_hints<L: Deref>(amt_msat: Option<u64>, phantom_route_hints: Vec<PhantomRouteHints>,
        logger: L) -> Vec<RouteHint>
 where
@@ -292,6 +293,33 @@ where
        }
 }
 
+/// Draw items iteratively from multiple iterators.  The items are retrieved by index and
+/// rotates through the iterators - first the zero index then the first index then second index, etc.
+fn rotate_through_iterators<T, I: Iterator<Item = T>>(mut vecs: Vec<I>) -> impl Iterator<Item = T> {
+       let mut iterations = 0;
+
+       core::iter::from_fn(move || {
+               let mut exhausted_iterators = 0;
+               loop {
+                       if vecs.is_empty() {
+                               return None;
+                       }
+                       let next_idx = iterations % vecs.len();
+                       iterations += 1;
+                       if let Some(item) = vecs[next_idx].next() {
+                               return Some(item);
+                       }
+                       // exhausted_vectors increase when the "next_idx" vector is exhausted
+                       exhausted_iterators += 1;
+                       // The check for exhausted iterators gets reset to 0 after each yield of `Some()`
+                       // The loop will return None when all of the nested iterators are exhausted
+                       if exhausted_iterators == vecs.len() {
+                               return None;
+                       }
+               }
+       })
+}
+
 #[cfg(feature = "std")]
 /// Utility to construct an invoice. Generally, unless you want to do something like a custom
 /// cltv_expiry, this is what you should be using to create an invoice. The reason being, this
@@ -629,7 +657,7 @@ fn sort_and_filter_channels<L: Deref>(
                                // previous channel to avoid announcing non-public channels.
                                let new_now_public = channel.is_public && !entry.get().is_public;
                                // Decide whether we prefer the currently selected channel with the node to the new one,
-                               // based on their inbound capacity. 
+                               // based on their inbound capacity.
                                let prefer_current = prefer_current_channel(min_inbound_capacity_msat, current_max_capacity,
                                        channel.inbound_capacity_msat);
                                // If the public-ness of the channel has not changed (in which case simply defer to
@@ -768,16 +796,16 @@ mod test {
        use crate::{Currency, Description, InvoiceDescription, SignOrCreationError, CreationError};
        use bitcoin_hashes::{Hash, sha256};
        use bitcoin_hashes::sha256::Hash as Sha256;
-       use lightning::chain::keysinterface::{EntropySource, PhantomKeysManager};
+       use lightning::sign::PhantomKeysManager;
        use lightning::events::{MessageSendEvent, MessageSendEventsProvider, Event};
        use lightning::ln::{PaymentPreimage, PaymentHash};
-       use lightning::ln::channelmanager::{PhantomRouteHints, MIN_FINAL_CLTV_EXPIRY_DELTA, PaymentId};
+       use lightning::ln::channelmanager::{PhantomRouteHints, MIN_FINAL_CLTV_EXPIRY_DELTA, PaymentId, RecipientOnionFields, Retry};
        use lightning::ln::functional_test_utils::*;
        use lightning::ln::msgs::ChannelMessageHandler;
-       use lightning::routing::router::{PaymentParameters, RouteParameters, find_route};
+       use lightning::routing::router::{PaymentParameters, RouteParameters};
        use lightning::util::test_utils;
        use lightning::util::config::UserConfig;
-       use crate::utils::create_invoice_from_channelmanager_and_duration_since_epoch;
+       use crate::utils::{create_invoice_from_channelmanager_and_duration_since_epoch, rotate_through_iterators};
        use std::collections::HashSet;
 
        #[test]
@@ -793,10 +821,10 @@ mod test {
 
                // Minimum set, prefer candidate channel over minimum + buffer.
                assert_eq!(crate::utils::prefer_current_channel(Some(100), 105, 125), false);
-               
+
                // Minimum set, both channels sufficient, prefer smaller current channel.
                assert_eq!(crate::utils::prefer_current_channel(Some(100), 115, 125), true);
-               
+
                // Minimum set, both channels sufficient, prefer smaller candidate channel.
                assert_eq!(crate::utils::prefer_current_channel(Some(100), 200, 160), false);
 
@@ -838,26 +866,18 @@ mod test {
 
                let payment_params = PaymentParameters::from_node_id(invoice.recover_payee_pub_key(),
                                invoice.min_final_cltv_expiry_delta() as u32)
-                       .with_features(invoice.features().unwrap().clone())
-                       .with_route_hints(invoice.route_hints());
+                       .with_bolt11_features(invoice.features().unwrap().clone()).unwrap()
+                       .with_route_hints(invoice.route_hints()).unwrap();
                let route_params = RouteParameters {
                        payment_params,
                        final_value_msat: invoice.amount_milli_satoshis().unwrap(),
                };
-               let first_hops = nodes[0].node.list_usable_channels();
-               let network_graph = &node_cfgs[0].network_graph;
-               let logger = test_utils::TestLogger::new();
-               let scorer = test_utils::TestScorer::new();
-               let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
-               let route = find_route(
-                       &nodes[0].node.get_our_node_id(), &route_params, network_graph,
-                       Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer, &random_seed_bytes
-               ).unwrap();
-
                let payment_event = {
                        let mut payment_hash = PaymentHash([0; 32]);
                        payment_hash.0.copy_from_slice(&invoice.payment_hash().as_ref()[0..32]);
-                       nodes[0].node.send_payment(&route, payment_hash, &Some(*invoice.payment_secret()), PaymentId(payment_hash.0)).unwrap();
+                       nodes[0].node.send_payment(payment_hash,
+                               RecipientOnionFields::secret_only(*invoice.payment_secret()),
+                               PaymentId(payment_hash.0), route_params, Retry::Attempts(0)).unwrap();
                        let mut added_monitors = nodes[0].chain_monitor.added_monitors.lock().unwrap();
                        assert_eq!(added_monitors.len(), 1);
                        added_monitors.clear();
@@ -1302,25 +1322,18 @@ mod test {
 
                let payment_params = PaymentParameters::from_node_id(invoice.recover_payee_pub_key(),
                                invoice.min_final_cltv_expiry_delta() as u32)
-                       .with_features(invoice.features().unwrap().clone())
-                       .with_route_hints(invoice.route_hints());
+                       .with_bolt11_features(invoice.features().unwrap().clone()).unwrap()
+                       .with_route_hints(invoice.route_hints()).unwrap();
                let params = RouteParameters {
                        payment_params,
                        final_value_msat: invoice.amount_milli_satoshis().unwrap(),
                };
-               let first_hops = nodes[0].node.list_usable_channels();
-               let network_graph = &node_cfgs[0].network_graph;
-               let logger = test_utils::TestLogger::new();
-               let scorer = test_utils::TestScorer::new();
-               let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
-               let route = find_route(
-                       &nodes[0].node.get_our_node_id(), &params, network_graph,
-                       Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer, &random_seed_bytes
-               ).unwrap();
                let (payment_event, fwd_idx) = {
                        let mut payment_hash = PaymentHash([0; 32]);
                        payment_hash.0.copy_from_slice(&invoice.payment_hash().as_ref()[0..32]);
-                       nodes[0].node.send_payment(&route, payment_hash, &Some(*invoice.payment_secret()), PaymentId(payment_hash.0)).unwrap();
+                       nodes[0].node.send_payment(payment_hash,
+                               RecipientOnionFields::secret_only(*invoice.payment_secret()),
+                               PaymentId(payment_hash.0), params, Retry::Attempts(0)).unwrap();
                        let mut added_monitors = nodes[0].chain_monitor.added_monitors.lock().unwrap();
                        assert_eq!(added_monitors.len(), 1);
                        added_monitors.clear();
@@ -1349,7 +1362,7 @@ mod test {
                nodes[fwd_idx].node.process_pending_htlc_forwards();
 
                let payment_preimage_opt = if user_generated_pmt_hash { None } else { Some(payment_preimage) };
-               expect_payment_claimable!(&nodes[fwd_idx], payment_hash, payment_secret, payment_amt, payment_preimage_opt, route.paths[0].last().unwrap().pubkey);
+               expect_payment_claimable!(&nodes[fwd_idx], payment_hash, payment_secret, payment_amt, payment_preimage_opt, invoice.recover_payee_pub_key());
                do_claim_payment_along_route(&nodes[0], &[&vec!(&nodes[fwd_idx])[..]], false, payment_preimage);
                let events = nodes[0].node.get_and_clear_pending_events();
                assert_eq!(events.len(), 2);
@@ -1901,4 +1914,111 @@ mod test {
                        _ => panic!(),
                }
        }
+
+       #[test]
+       fn test_rotate_through_iterators() {
+               // two nested vectors
+               let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec!["a1", "b1"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "a1", "b0", "b1", "c0"];
+               assert_eq!(expected, result);
+
+               // test single nested vector
+               let a = vec![vec!["a0", "b0", "c0"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "b0", "c0"];
+               assert_eq!(expected, result);
+
+               // test second vector with only one element
+               let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec!["a1"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "a1", "b0", "c0"];
+               assert_eq!(expected, result);
+
+               // test three nestend vectors
+               let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec!["a2"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "a1", "a2", "b1", "c1"];
+               assert_eq!(expected, result);
+
+               // test single nested vector with a single value
+               let a = vec![vec!["a0"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0"];
+               assert_eq!(expected, result);
+
+               // test single empty nested vector
+               let a:Vec<std::vec::IntoIter<&str>> = vec![vec![].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<&str>>();
+               let expected:Vec<&str> = vec![];
+
+               assert_eq!(expected, result);
+
+               // test first nested vector is empty
+               let a:Vec<std::vec::IntoIter<&str>>= vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<&str>>();
+
+               let expected = vec!["a1", "b1", "c1"];
+               assert_eq!(expected, result);
+
+               // test two empty vectors
+               let a:Vec<std::vec::IntoIter<&str>> = vec![vec![].into_iter(), vec![].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<&str>>();
+
+               let expected:Vec<&str> = vec![];
+               assert_eq!(expected, result);
+
+               // test an empty vector amongst other filled vectors
+               let a = vec![
+                       vec!["a0", "b0", "c0"].into_iter(),
+                       vec![].into_iter(),
+                       vec!["a1", "b1", "c1"].into_iter(),
+                       vec!["a2", "b2", "c2"].into_iter(),
+               ];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "a1", "a2", "b0", "b1", "b2", "c0", "c1", "c2"];
+               assert_eq!(expected, result);
+
+               // test a filled vector between two empty vectors
+               let a = vec![vec![].into_iter(), vec!["a1", "b1", "c1"].into_iter(), vec![].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a1", "b1", "c1"];
+               assert_eq!(expected, result);
+
+               // test an empty vector at the end of the vectors
+               let a = vec![vec!["a0", "b0", "c0"].into_iter(), vec![].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "b0", "c0"];
+               assert_eq!(expected, result);
+
+               // test multiple empty vectors amongst multiple filled vectors
+               let a = vec![
+                       vec![].into_iter(),
+                       vec!["a1", "b1", "c1"].into_iter(),
+                       vec![].into_iter(),
+                       vec!["a3", "b3"].into_iter(),
+                       vec![].into_iter(),
+               ];
+
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a1", "a3", "b1", "b3", "c1"];
+               assert_eq!(expected, result);
+
+               // test one element in the first nested vectore and two elements in the second nested
+               // vector
+               let a = vec![vec!["a0"].into_iter(), vec!["a1", "b1"].into_iter()];
+               let result = rotate_through_iterators(a).collect::<Vec<_>>();
+
+               let expected = vec!["a0", "a1", "b1"];
+               assert_eq!(expected, result);
+       }
 }