X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Ftest_utils.rs;h=9ff5d76ef8db0fe668d1033cc2ed11af44d16438;hb=336c77c738c792eb8cc1b2ee0e78ff96f106f753;hp=43518a3fd86cf2c9e8984e953d45b657c8b3bb68;hpb=806b7f0e312c59c87fd628fb71e7c4a77a39645a;p=rust-lightning diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 43518a3f..9ff5d76e 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -112,7 +112,7 @@ pub struct TestRouter<'a> { >, //pub entropy_source: &'a RandomBytes, pub network_graph: Arc>, - pub next_routes: Mutex)>>, + pub next_routes: Mutex>)>>, pub scorer: &'a RwLock, } @@ -132,7 +132,12 @@ impl<'a> TestRouter<'a> { pub fn expect_find_route(&self, query: RouteParameters, result: Result) { let mut expected_routes = self.next_routes.lock().unwrap(); - expected_routes.push_back((query, result)); + expected_routes.push_back((query, Some(result))); + } + + pub fn expect_find_route_query(&self, query: RouteParameters) { + let mut expected_routes = self.next_routes.lock().unwrap(); + expected_routes.push_back((query, None)); } } @@ -145,63 +150,67 @@ impl<'a> Router for TestRouter<'a> { let next_route_opt = self.next_routes.lock().unwrap().pop_front(); if let Some((find_route_query, find_route_res)) = next_route_opt { assert_eq!(find_route_query, *params); - if let Ok(ref route) = find_route_res { - assert_eq!(route.route_params, Some(find_route_query)); - let scorer = self.scorer.read().unwrap(); - let scorer = ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs); - for path in &route.paths { - let mut aggregate_msat = 0u64; - let mut prev_hop_node = payer; - for (idx, hop) in path.hops.iter().rev().enumerate() { - aggregate_msat += hop.fee_msat; - let usage = ChannelUsage { - amount_msat: aggregate_msat, - inflight_htlc_msat: 0, - effective_capacity: EffectiveCapacity::Unknown, - }; - - if idx == path.hops.len() - 1 { - if let Some(first_hops) = first_hops { - if let Some(idx) = first_hops.iter().position(|h| h.get_outbound_payment_scid() == Some(hop.short_channel_id)) { - let node_id = NodeId::from_pubkey(payer); - let candidate = CandidateRouteHop::FirstHop(FirstHopCandidate { - details: first_hops[idx], - payer_node_id: &node_id, - }); - scorer.channel_penalty_msat(&candidate, usage, &Default::default()); - continue; + if let Some(res) = find_route_res { + if let Ok(ref route) = res { + assert_eq!(route.route_params, Some(find_route_query)); + let scorer = self.scorer.read().unwrap(); + let scorer = ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs); + for path in &route.paths { + let mut aggregate_msat = 0u64; + let mut prev_hop_node = payer; + for (idx, hop) in path.hops.iter().rev().enumerate() { + aggregate_msat += hop.fee_msat; + let usage = ChannelUsage { + amount_msat: aggregate_msat, + inflight_htlc_msat: 0, + effective_capacity: EffectiveCapacity::Unknown, + }; + + if idx == path.hops.len() - 1 { + if let Some(first_hops) = first_hops { + if let Some(idx) = first_hops.iter().position(|h| h.get_outbound_payment_scid() == Some(hop.short_channel_id)) { + let node_id = NodeId::from_pubkey(payer); + let candidate = CandidateRouteHop::FirstHop(FirstHopCandidate { + details: first_hops[idx], + payer_node_id: &node_id, + }); + scorer.channel_penalty_msat(&candidate, usage, &Default::default()); + continue; + } } } + let network_graph = self.network_graph.read_only(); + if let Some(channel) = network_graph.channel(hop.short_channel_id) { + let (directed, _) = channel.as_directed_to(&NodeId::from_pubkey(&hop.pubkey)).unwrap(); + let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate { + info: directed, + short_channel_id: hop.short_channel_id, + }); + scorer.channel_penalty_msat(&candidate, usage, &Default::default()); + } else { + let target_node_id = NodeId::from_pubkey(&hop.pubkey); + let route_hint = RouteHintHop { + src_node_id: *prev_hop_node, + short_channel_id: hop.short_channel_id, + fees: RoutingFees { base_msat: 0, proportional_millionths: 0 }, + cltv_expiry_delta: 0, + htlc_minimum_msat: None, + htlc_maximum_msat: None, + }; + let candidate = CandidateRouteHop::PrivateHop(PrivateHopCandidate { + hint: &route_hint, + target_node_id: &target_node_id, + }); + scorer.channel_penalty_msat(&candidate, usage, &Default::default()); + } + prev_hop_node = &hop.pubkey; } - let network_graph = self.network_graph.read_only(); - if let Some(channel) = network_graph.channel(hop.short_channel_id) { - let (directed, _) = channel.as_directed_to(&NodeId::from_pubkey(&hop.pubkey)).unwrap(); - let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate { - info: directed, - short_channel_id: hop.short_channel_id, - }); - scorer.channel_penalty_msat(&candidate, usage, &Default::default()); - } else { - let target_node_id = NodeId::from_pubkey(&hop.pubkey); - let route_hint = RouteHintHop { - src_node_id: *prev_hop_node, - short_channel_id: hop.short_channel_id, - fees: RoutingFees { base_msat: 0, proportional_millionths: 0 }, - cltv_expiry_delta: 0, - htlc_minimum_msat: None, - htlc_maximum_msat: None, - }; - let candidate = CandidateRouteHop::PrivateHop(PrivateHopCandidate { - hint: &route_hint, - target_node_id: &target_node_id, - }); - scorer.channel_penalty_msat(&candidate, usage, &Default::default()); - } - prev_hop_node = &hop.pubkey; } } + route_res = res; + } else { + route_res = self.router.find_route(payer, params, first_hops, inflight_htlcs); } - route_res = find_route_res; } else { route_res = self.router.find_route(payer, params, first_hops, inflight_htlcs); }; @@ -450,7 +459,7 @@ impl WatchtowerPersister { } #[cfg(test)] -impl chainmonitor::Persist for WatchtowerPersister { +impl chainmonitor::Persist for WatchtowerPersister { fn persist_new_channel(&self, funding_txo: OutPoint, data: &channelmonitor::ChannelMonitor ) -> chain::ChannelMonitorUpdateStatus { @@ -513,9 +522,6 @@ pub struct TestPersister { /// The queue of update statuses we'll return. If none are queued, ::Completed will always be /// returned. pub update_rets: Mutex>, - /// When we get an update_persisted_channel call with no ChannelMonitorUpdate, we insert the - /// MonitorId here. - pub chain_sync_monitor_persistences: Mutex>, /// When we get an update_persisted_channel call *with* a ChannelMonitorUpdate, we insert the /// [`ChannelMonitor::get_latest_update_id`] here. /// @@ -526,7 +532,6 @@ impl TestPersister { pub fn new() -> Self { Self { update_rets: Mutex::new(VecDeque::new()), - chain_sync_monitor_persistences: Mutex::new(VecDeque::new()), offchain_monitor_updates: Mutex::new(new_hash_map()), } } @@ -536,7 +541,7 @@ impl TestPersister { self.update_rets.lock().unwrap().push_back(next_ret); } } -impl chainmonitor::Persist for TestPersister { +impl chainmonitor::Persist for TestPersister { fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor) -> chain::ChannelMonitorUpdateStatus { if let Some(update_ret) = self.update_rets.lock().unwrap().pop_front() { return update_ret @@ -552,22 +557,13 @@ impl chainmonitor::Persist {}, - None => { - // If the channel was not in the offchain_monitor_updates map, it should be in the - // chain_sync_monitor_persistences map. - self.chain_sync_monitor_persistences.lock().unwrap().retain(|x| x != &funding_txo); - } - }; + self.offchain_monitor_updates.lock().unwrap().remove(&funding_txo); } } @@ -1384,7 +1380,7 @@ impl TestChainSource { } } pub fn remove_watched_txn_and_outputs(&self, outpoint: OutPoint, script_pubkey: ScriptBuf) { - self.watched_outputs.lock().unwrap().remove(&(outpoint, script_pubkey.clone())); + self.watched_outputs.lock().unwrap().remove(&(outpoint, script_pubkey.clone())); self.watched_txn.lock().unwrap().remove(&(outpoint.txid, script_pubkey)); } }