+pub struct TestRouter<'a> {
+ pub network_graph: Arc<NetworkGraph<&'a TestLogger>>,
+ pub next_routes: Mutex<VecDeque<(RouteParameters, Result<Route, LightningError>)>>,
+ pub scorer: &'a Mutex<TestScorer>,
+}
+
+impl<'a> TestRouter<'a> {
+ pub fn new(network_graph: Arc<NetworkGraph<&'a TestLogger>>, scorer: &'a Mutex<TestScorer>) -> Self {
+ Self { network_graph, next_routes: Mutex::new(VecDeque::new()), scorer }
+ }
+
+ pub fn expect_find_route(&self, query: RouteParameters, result: Result<Route, LightningError>) {
+ let mut expected_routes = self.next_routes.lock().unwrap();
+ expected_routes.push_back((query, result));
+ }
+}
+
+impl<'a> Router for TestRouter<'a> {
+ fn find_route(
+ &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&channelmanager::ChannelDetails]>,
+ inflight_htlcs: &InFlightHtlcs
+ ) -> Result<Route, msgs::LightningError> {
+ if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
+ assert_eq!(find_route_query, *params);
+ if let Ok(ref route) = find_route_res {
+ let locked_scorer = self.scorer.lock().unwrap();
+ let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs);
+ for path in &route.paths {
+ let mut aggregate_msat = 0u64;
+ for (idx, hop) in path.iter().rev().enumerate() {
+ aggregate_msat += hop.fee_msat;
+ let usage = ChannelUsage {
+ amount_msat: aggregate_msat,
+ inflight_htlc_msat: 0,
+ effective_capacity: EffectiveCapacity::Unknown,
+ };
+
+ // Since the path is reversed, the last element in our iteration is the first
+ // hop.
+ if idx == path.len() - 1 {
+ scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(payer), &NodeId::from_pubkey(&hop.pubkey), usage);
+ } else {
+ let curr_hop_path_idx = path.len() - 1 - idx;
+ scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(&path[curr_hop_path_idx - 1].pubkey), &NodeId::from_pubkey(&hop.pubkey), usage);
+ }
+ }
+ }
+ }
+ return find_route_res;
+ }
+ let logger = TestLogger::new();
+ let scorer = self.scorer.lock().unwrap();
+ find_route(
+ payer, params, &self.network_graph, first_hops, &logger,
+ &ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs),
+ &[42; 32]
+ )
+ }
+}
+
+impl<'a> Drop for TestRouter<'a> {
+ fn drop(&mut self) {
+ #[cfg(feature = "std")] {
+ if std::thread::panicking() {
+ return;
+ }
+ }
+ assert!(self.next_routes.lock().unwrap().is_empty());
+ }
+}
+