TestRouter: support checking only that route params are as expected.
[rust-lightning] / lightning / src / util / test_utils.rs
index 43518a3fd86cf2c9e8984e953d45b657c8b3bb68..9ff5d76ef8db0fe668d1033cc2ed11af44d16438 100644 (file)
@@ -112,7 +112,7 @@ pub struct TestRouter<'a> {
        >,
        //pub entropy_source: &'a RandomBytes,
        pub network_graph: Arc<NetworkGraph<&'a TestLogger>>,
-       pub next_routes: Mutex<VecDeque<(RouteParameters, Result<Route, LightningError>)>>,
+       pub next_routes: Mutex<VecDeque<(RouteParameters, Option<Result<Route, LightningError>>)>>,
        pub scorer: &'a RwLock<TestScorer>,
 }
 
@@ -132,7 +132,12 @@ impl<'a> TestRouter<'a> {
 
        pub fn expect_find_route(&self, query: RouteParameters, result: Result<Route, LightningError>) {
                let mut expected_routes = self.next_routes.lock().unwrap();
-               expected_routes.push_back((query, result));
+               expected_routes.push_back((query, Some(result)));
+       }
+
+       pub fn expect_find_route_query(&self, query: RouteParameters) {
+               let mut expected_routes = self.next_routes.lock().unwrap();
+               expected_routes.push_back((query, None));
        }
 }
 
@@ -145,63 +150,67 @@ impl<'a> Router for TestRouter<'a> {
                let next_route_opt = self.next_routes.lock().unwrap().pop_front();
                if let Some((find_route_query, find_route_res)) = next_route_opt {
                        assert_eq!(find_route_query, *params);
-                       if let Ok(ref route) = find_route_res {
-                               assert_eq!(route.route_params, Some(find_route_query));
-                               let scorer = self.scorer.read().unwrap();
-                               let scorer = ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs);
-                               for path in &route.paths {
-                                       let mut aggregate_msat = 0u64;
-                                       let mut prev_hop_node = payer;
-                                       for (idx, hop) in path.hops.iter().rev().enumerate() {
-                                               aggregate_msat += hop.fee_msat;
-                                               let usage = ChannelUsage {
-                                                       amount_msat: aggregate_msat,
-                                                       inflight_htlc_msat: 0,
-                                                       effective_capacity: EffectiveCapacity::Unknown,
-                                               };
-
-                                               if idx == path.hops.len() - 1 {
-                                                       if let Some(first_hops) = first_hops {
-                                                               if let Some(idx) = first_hops.iter().position(|h| h.get_outbound_payment_scid() == Some(hop.short_channel_id)) {
-                                                                       let node_id = NodeId::from_pubkey(payer);
-                                                                       let candidate = CandidateRouteHop::FirstHop(FirstHopCandidate {
-                                                                               details: first_hops[idx],
-                                                                               payer_node_id: &node_id,
-                                                                       });
-                                                                       scorer.channel_penalty_msat(&candidate, usage, &Default::default());
-                                                                       continue;
+                       if let Some(res) = find_route_res {
+                               if let Ok(ref route) = res {
+                                       assert_eq!(route.route_params, Some(find_route_query));
+                                       let scorer = self.scorer.read().unwrap();
+                                       let scorer = ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs);
+                                       for path in &route.paths {
+                                               let mut aggregate_msat = 0u64;
+                                               let mut prev_hop_node = payer;
+                                               for (idx, hop) in path.hops.iter().rev().enumerate() {
+                                                       aggregate_msat += hop.fee_msat;
+                                                       let usage = ChannelUsage {
+                                                               amount_msat: aggregate_msat,
+                                                               inflight_htlc_msat: 0,
+                                                               effective_capacity: EffectiveCapacity::Unknown,
+                                                       };
+
+                                                       if idx == path.hops.len() - 1 {
+                                                               if let Some(first_hops) = first_hops {
+                                                                       if let Some(idx) = first_hops.iter().position(|h| h.get_outbound_payment_scid() == Some(hop.short_channel_id)) {
+                                                                               let node_id = NodeId::from_pubkey(payer);
+                                                                               let candidate = CandidateRouteHop::FirstHop(FirstHopCandidate {
+                                                                                       details: first_hops[idx],
+                                                                                       payer_node_id: &node_id,
+                                                                               });
+                                                                               scorer.channel_penalty_msat(&candidate, usage, &Default::default());
+                                                                               continue;
+                                                                       }
                                                                }
                                                        }
+                                                       let network_graph = self.network_graph.read_only();
+                                                       if let Some(channel) = network_graph.channel(hop.short_channel_id) {
+                                                               let (directed, _) = channel.as_directed_to(&NodeId::from_pubkey(&hop.pubkey)).unwrap();
+                                                               let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate {
+                                                                       info: directed,
+                                                                       short_channel_id: hop.short_channel_id,
+                                                               });
+                                                               scorer.channel_penalty_msat(&candidate, usage, &Default::default());
+                                                       } else {
+                                                               let target_node_id = NodeId::from_pubkey(&hop.pubkey);
+                                                               let route_hint = RouteHintHop {
+                                                                       src_node_id: *prev_hop_node,
+                                                                       short_channel_id: hop.short_channel_id,
+                                                                       fees: RoutingFees { base_msat: 0, proportional_millionths: 0 },
+                                                                       cltv_expiry_delta: 0,
+                                                                       htlc_minimum_msat: None,
+                                                                       htlc_maximum_msat: None,
+                                                               };
+                                                               let candidate = CandidateRouteHop::PrivateHop(PrivateHopCandidate {
+                                                                       hint: &route_hint,
+                                                                       target_node_id: &target_node_id,
+                                                               });
+                                                               scorer.channel_penalty_msat(&candidate, usage, &Default::default());
+                                                       }
+                                                       prev_hop_node = &hop.pubkey;
                                                }
-                                               let network_graph = self.network_graph.read_only();
-                                               if let Some(channel) = network_graph.channel(hop.short_channel_id) {
-                                                       let (directed, _) = channel.as_directed_to(&NodeId::from_pubkey(&hop.pubkey)).unwrap();
-                                                       let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate {
-                                                               info: directed,
-                                                               short_channel_id: hop.short_channel_id,
-                                                       });
-                                                       scorer.channel_penalty_msat(&candidate, usage, &Default::default());
-                                               } else {
-                                                       let target_node_id = NodeId::from_pubkey(&hop.pubkey);
-                                                       let route_hint = RouteHintHop {
-                                                               src_node_id: *prev_hop_node,
-                                                               short_channel_id: hop.short_channel_id,
-                                                               fees: RoutingFees { base_msat: 0, proportional_millionths: 0 },
-                                                               cltv_expiry_delta: 0,
-                                                               htlc_minimum_msat: None,
-                                                               htlc_maximum_msat: None,
-                                                       };
-                                                       let candidate = CandidateRouteHop::PrivateHop(PrivateHopCandidate {
-                                                               hint: &route_hint,
-                                                               target_node_id: &target_node_id,
-                                                       });
-                                                       scorer.channel_penalty_msat(&candidate, usage, &Default::default());
-                                               }
-                                               prev_hop_node = &hop.pubkey;
                                        }
                                }
+                               route_res = res;
+                       } else {
+                               route_res = self.router.find_route(payer, params, first_hops, inflight_htlcs);
                        }
-                       route_res = find_route_res;
                } else {
                        route_res = self.router.find_route(payer, params, first_hops, inflight_htlcs);
                };
@@ -450,7 +459,7 @@ impl WatchtowerPersister {
 }
 
 #[cfg(test)]
-impl<Signer: sign::ecdsa::WriteableEcdsaChannelSigner> chainmonitor::Persist<Signer> for WatchtowerPersister {
+impl<Signer: sign::ecdsa::EcdsaChannelSigner> chainmonitor::Persist<Signer> for WatchtowerPersister {
        fn persist_new_channel(&self, funding_txo: OutPoint,
                data: &channelmonitor::ChannelMonitor<Signer>
        ) -> chain::ChannelMonitorUpdateStatus {
@@ -513,9 +522,6 @@ pub struct TestPersister {
        /// The queue of update statuses we'll return. If none are queued, ::Completed will always be
        /// returned.
        pub update_rets: Mutex<VecDeque<chain::ChannelMonitorUpdateStatus>>,
-       /// When we get an update_persisted_channel call with no ChannelMonitorUpdate, we insert the
-       /// MonitorId here.
-       pub chain_sync_monitor_persistences: Mutex<VecDeque<OutPoint>>,
        /// When we get an update_persisted_channel call *with* a ChannelMonitorUpdate, we insert the
        /// [`ChannelMonitor::get_latest_update_id`] here.
        ///
@@ -526,7 +532,6 @@ impl TestPersister {
        pub fn new() -> Self {
                Self {
                        update_rets: Mutex::new(VecDeque::new()),
-                       chain_sync_monitor_persistences: Mutex::new(VecDeque::new()),
                        offchain_monitor_updates: Mutex::new(new_hash_map()),
                }
        }
@@ -536,7 +541,7 @@ impl TestPersister {
                self.update_rets.lock().unwrap().push_back(next_ret);
        }
 }
-impl<Signer: sign::ecdsa::WriteableEcdsaChannelSigner> chainmonitor::Persist<Signer> for TestPersister {
+impl<Signer: sign::ecdsa::EcdsaChannelSigner> chainmonitor::Persist<Signer> for TestPersister {
        fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor<Signer>) -> chain::ChannelMonitorUpdateStatus {
                if let Some(update_ret) = self.update_rets.lock().unwrap().pop_front() {
                        return update_ret
@@ -552,22 +557,13 @@ impl<Signer: sign::ecdsa::WriteableEcdsaChannelSigner> chainmonitor::Persist<Sig
 
                if let Some(update) = update  {
                        self.offchain_monitor_updates.lock().unwrap().entry(funding_txo).or_insert(new_hash_set()).insert(update.update_id);
-               } else {
-                       self.chain_sync_monitor_persistences.lock().unwrap().push_back(funding_txo);
                }
                ret
        }
 
-       fn archive_persisted_channel(&self, funding_txo: OutPoint) { 
+       fn archive_persisted_channel(&self, funding_txo: OutPoint) {
                // remove the channel from the offchain_monitor_updates map
-               match self.offchain_monitor_updates.lock().unwrap().remove(&funding_txo) {
-                       Some(_) => {},
-                       None => {
-                               // If the channel was not in the offchain_monitor_updates map, it should be in the
-                               // chain_sync_monitor_persistences map.
-                               self.chain_sync_monitor_persistences.lock().unwrap().retain(|x| x != &funding_txo);
-                       }
-               };
+               self.offchain_monitor_updates.lock().unwrap().remove(&funding_txo);
        }
 }
 
@@ -1384,7 +1380,7 @@ impl TestChainSource {
                }
        }
        pub fn remove_watched_txn_and_outputs(&self, outpoint: OutPoint, script_pubkey: ScriptBuf) {
-               self.watched_outputs.lock().unwrap().remove(&(outpoint, script_pubkey.clone())); 
+               self.watched_outputs.lock().unwrap().remove(&(outpoint, script_pubkey.clone()));
                self.watched_txn.lock().unwrap().remove(&(outpoint.txid, script_pubkey));
        }
 }