X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Ftest_utils.rs;h=387f9df677e969def3eca8bde64186213854a3be;hb=cc4bc1df5a1ab6c363fac3c1b7eddce362e167c0;hp=a6c2d77451a411fac9ac1e5f2ba3fffab16ab3d2;hpb=37150b4d697f723f278b6cf3f63892df272b9aa5;p=rust-lightning diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index a6c2d774..387f9df6 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -17,6 +17,7 @@ use crate::chain::chainmonitor::{MonitorUpdateId, UpdateOrigin}; use crate::chain::channelmonitor; use crate::chain::channelmonitor::MonitorEvent; use crate::chain::transaction::OutPoint; +use crate::routing::router::CandidateRouteHop; use crate::sign; use crate::events; use crate::events::bump_transaction::{WalletSource, Utxo}; @@ -139,10 +140,35 @@ impl<'a> Router for TestRouter<'a> { // Since the path is reversed, the last element in our iteration is the first // hop. if idx == path.hops.len() - 1 { - scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(payer), &NodeId::from_pubkey(&hop.pubkey), usage, &Default::default()); + let first_hops = match first_hops { + Some(hops) => hops, + None => continue, + }; + if first_hops.len() == 0 { + continue; + } + let idx = if first_hops.len() > 1 { route.paths.iter().position(|p| p == path).unwrap_or(0) } else { 0 }; + let node_id = NodeId::from_pubkey(payer); + let candidate = CandidateRouteHop::FirstHop { + details: first_hops[idx], + payer_node_id: &node_id, + }; + scorer.channel_penalty_msat(&candidate, usage, &()); } else { - let curr_hop_path_idx = path.hops.len() - 1 - idx; - scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(&path.hops[curr_hop_path_idx - 1].pubkey), &NodeId::from_pubkey(&hop.pubkey), usage, &Default::default()); + let network_graph = self.network_graph.read_only(); + let channel = match network_graph.channel(hop.short_channel_id) { + Some(channel) => channel, + None => continue, + }; + let channel = match channel.as_directed_to(&NodeId::from_pubkey(&hop.pubkey)) { + Some(channel) => channel, + None => panic!("Channel directed to {} was not found", hop.pubkey), + }; + let candidate = CandidateRouteHop::PublicHop { + info: channel.0, + short_channel_id: hop.short_channel_id, + }; + scorer.channel_penalty_msat(&candidate, usage, &()); } } } @@ -1297,8 +1323,12 @@ impl crate::util::ser::Writeable for TestScorer { impl ScoreLookUp for TestScorer { type ScoreParams = (); fn channel_penalty_msat( - &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId, usage: ChannelUsage, _score_params: &Self::ScoreParams + &self, candidate: &CandidateRouteHop, usage: ChannelUsage, _score_params: &Self::ScoreParams ) -> u64 { + let short_channel_id = match candidate.globally_unique_short_channel_id() { + Some(scid) => scid, + None => return 0, + }; if let Some(scorer_expectations) = self.scorer_expectations.borrow_mut().as_mut() { match scorer_expectations.pop_front() { Some((scid, expectation)) => {