X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Ftest_utils.rs;h=9bbbec0d78e6a182e6e8e2553e60642db9645f52;hb=d99089e16a6e7c4744af5dda0750a7c7a17caba6;hp=a6c2d77451a411fac9ac1e5f2ba3fffab16ab3d2;hpb=12c2086d58b54e7d2b1356a3a45d3a605dcd2972;p=rust-lightning diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index a6c2d774..9bbbec0d 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -17,6 +17,7 @@ use crate::chain::chainmonitor::{MonitorUpdateId, UpdateOrigin}; use crate::chain::channelmonitor; use crate::chain::channelmonitor::MonitorEvent; use crate::chain::transaction::OutPoint; +use crate::routing::router::CandidateRouteHop; use crate::sign; use crate::events; use crate::events::bump_transaction::{WalletSource, Utxo}; @@ -139,10 +140,34 @@ impl<'a> Router for TestRouter<'a> { // Since the path is reversed, the last element in our iteration is the first // hop. if idx == path.hops.len() - 1 { - scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(payer), &NodeId::from_pubkey(&hop.pubkey), usage, &Default::default()); + let first_hops = match first_hops { + Some(hops) => hops, + None => continue, + }; + if first_hops.len() == 0 { + continue; + } + let idx = if first_hops.len() > 1 { route.paths.iter().position(|p| p == path).unwrap_or(0) } else { 0 }; + let candidate = CandidateRouteHop::FirstHop { + details: first_hops[idx], + node_id: NodeId::from_pubkey(payer) + }; + scorer.channel_penalty_msat(&candidate, usage, &()); } else { - let curr_hop_path_idx = path.hops.len() - 1 - idx; - scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(&path.hops[curr_hop_path_idx - 1].pubkey), &NodeId::from_pubkey(&hop.pubkey), usage, &Default::default()); + let network_graph = self.network_graph.read_only(); + let channel = match network_graph.channel(hop.short_channel_id) { + Some(channel) => channel, + None => continue, + }; + let channel = match channel.as_directed_to(&NodeId::from_pubkey(&hop.pubkey)) { + Some(channel) => channel, + None => panic!("Channel directed to {} was not found", hop.pubkey), + }; + let candidate = CandidateRouteHop::PublicHop { + info: channel.0, + short_channel_id: hop.short_channel_id, + }; + scorer.channel_penalty_msat(&candidate, usage, &()); } } } @@ -1297,8 +1322,12 @@ impl crate::util::ser::Writeable for TestScorer { impl ScoreLookUp for TestScorer { type ScoreParams = (); fn channel_penalty_msat( - &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId, usage: ChannelUsage, _score_params: &Self::ScoreParams + &self, candidate: &CandidateRouteHop, usage: ChannelUsage, _score_params: &Self::ScoreParams ) -> u64 { + let short_channel_id = match candidate.short_channel_id() { + Some(scid) => scid, + None => return 0, + }; if let Some(scorer_expectations) = self.scorer_expectations.borrow_mut().as_mut() { match scorer_expectations.pop_front() { Some((scid, expectation)) => {