+pub struct TestRouter<'a> {
+ pub router: DefaultRouter<
+ Arc<NetworkGraph<&'a TestLogger>>,
+ &'a TestLogger,
+ Arc<RandomBytes>,
+ &'a RwLock<TestScorer>,
+ (),
+ TestScorer,
+ >,
+ //pub entropy_source: &'a RandomBytes,
+ pub network_graph: Arc<NetworkGraph<&'a TestLogger>>,
+ pub next_routes: Mutex<VecDeque<(RouteParameters, Option<Result<Route, LightningError>>)>>,
+ pub scorer: &'a RwLock<TestScorer>,
+}
+
+impl<'a> TestRouter<'a> {
+ pub fn new(
+ network_graph: Arc<NetworkGraph<&'a TestLogger>>, logger: &'a TestLogger,
+ scorer: &'a RwLock<TestScorer>,
+ ) -> Self {
+ let entropy_source = Arc::new(RandomBytes::new([42; 32]));
+ Self {
+ router: DefaultRouter::new(network_graph.clone(), logger, entropy_source, scorer, ()),
+ network_graph,
+ next_routes: Mutex::new(VecDeque::new()),
+ scorer,
+ }
+ }
+
+ pub fn expect_find_route(&self, query: RouteParameters, result: Result<Route, LightningError>) {
+ let mut expected_routes = self.next_routes.lock().unwrap();
+ expected_routes.push_back((query, Some(result)));
+ }
+
+ pub fn expect_find_route_query(&self, query: RouteParameters) {
+ let mut expected_routes = self.next_routes.lock().unwrap();
+ expected_routes.push_back((query, None));
+ }
+}
+
+impl<'a> Router for TestRouter<'a> {
+ fn find_route(
+ &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
+ inflight_htlcs: InFlightHtlcs
+ ) -> Result<Route, msgs::LightningError> {
+ let route_res;
+ let next_route_opt = self.next_routes.lock().unwrap().pop_front();
+ if let Some((find_route_query, find_route_res)) = next_route_opt {
+ assert_eq!(find_route_query, *params);
+ if let Some(res) = find_route_res {
+ if let Ok(ref route) = res {
+ assert_eq!(route.route_params, Some(find_route_query));
+ let scorer = self.scorer.read().unwrap();
+ let scorer = ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs);
+ for path in &route.paths {
+ let mut aggregate_msat = 0u64;
+ let mut prev_hop_node = payer;
+ for (idx, hop) in path.hops.iter().rev().enumerate() {
+ aggregate_msat += hop.fee_msat;
+ let usage = ChannelUsage {
+ amount_msat: aggregate_msat,
+ inflight_htlc_msat: 0,
+ effective_capacity: EffectiveCapacity::Unknown,
+ };
+
+ if idx == path.hops.len() - 1 {
+ if let Some(first_hops) = first_hops {
+ if let Some(idx) = first_hops.iter().position(|h| h.get_outbound_payment_scid() == Some(hop.short_channel_id)) {
+ let node_id = NodeId::from_pubkey(payer);
+ let candidate = CandidateRouteHop::FirstHop(FirstHopCandidate {
+ details: first_hops[idx],
+ payer_node_id: &node_id,
+ });
+ scorer.channel_penalty_msat(&candidate, usage, &Default::default());
+ continue;
+ }
+ }
+ }
+ let network_graph = self.network_graph.read_only();
+ if let Some(channel) = network_graph.channel(hop.short_channel_id) {
+ let (directed, _) = channel.as_directed_to(&NodeId::from_pubkey(&hop.pubkey)).unwrap();
+ let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate {
+ info: directed,
+ short_channel_id: hop.short_channel_id,
+ });
+ scorer.channel_penalty_msat(&candidate, usage, &Default::default());
+ } else {
+ let target_node_id = NodeId::from_pubkey(&hop.pubkey);
+ let route_hint = RouteHintHop {
+ src_node_id: *prev_hop_node,
+ short_channel_id: hop.short_channel_id,
+ fees: RoutingFees { base_msat: 0, proportional_millionths: 0 },
+ cltv_expiry_delta: 0,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ };
+ let candidate = CandidateRouteHop::PrivateHop(PrivateHopCandidate {
+ hint: &route_hint,
+ target_node_id: &target_node_id,
+ });
+ scorer.channel_penalty_msat(&candidate, usage, &Default::default());
+ }
+ prev_hop_node = &hop.pubkey;
+ }
+ }
+ }
+ route_res = res;
+ } else {
+ route_res = self.router.find_route(payer, params, first_hops, inflight_htlcs);
+ }
+ } else {
+ route_res = self.router.find_route(payer, params, first_hops, inflight_htlcs);
+ };
+
+ if let Ok(route) = &route_res {
+ // Previously, `Route`s failed to round-trip through serialization due to a write/read
+ // mismatch. Thus, here we test all test-generated routes round-trip:
+ let ser = route.encode();
+ assert_eq!(Route::read(&mut &ser[..]).unwrap(), *route);
+ }
+ route_res
+ }
+
+ fn create_blinded_payment_paths<
+ T: secp256k1::Signing + secp256k1::Verification
+ >(
+ &self, recipient: PublicKey, first_hops: Vec<ChannelDetails>, tlvs: ReceiveTlvs,
+ amount_msats: u64, secp_ctx: &Secp256k1<T>,
+ ) -> Result<Vec<(BlindedPayInfo, BlindedPath)>, ()> {
+ self.router.create_blinded_payment_paths(
+ recipient, first_hops, tlvs, amount_msats, secp_ctx
+ )
+ }
+}
+
+impl<'a> MessageRouter for TestRouter<'a> {
+ fn find_path(
+ &self, sender: PublicKey, peers: Vec<PublicKey>, destination: Destination
+ ) -> Result<OnionMessagePath, ()> {
+ self.router.find_path(sender, peers, destination)
+ }
+
+ fn create_blinded_paths<
+ T: secp256k1::Signing + secp256k1::Verification
+ >(
+ &self, recipient: PublicKey, peers: Vec<PublicKey>, secp_ctx: &Secp256k1<T>,
+ ) -> Result<Vec<BlindedPath>, ()> {
+ self.router.create_blinded_paths(recipient, peers, secp_ctx)
+ }
+}
+
+impl<'a> Drop for TestRouter<'a> {
+ fn drop(&mut self) {
+ #[cfg(feature = "std")] {
+ if std::thread::panicking() {
+ return;
+ }
+ }
+ assert!(self.next_routes.lock().unwrap().is_empty());
+ }
+}
+
+pub struct TestMessageRouter<'a> {
+ inner: DefaultMessageRouter<Arc<NetworkGraph<&'a TestLogger>>, &'a TestLogger, &'a TestKeysInterface>,
+}
+
+impl<'a> TestMessageRouter<'a> {
+ pub fn new(network_graph: Arc<NetworkGraph<&'a TestLogger>>, entropy_source: &'a TestKeysInterface) -> Self {
+ Self { inner: DefaultMessageRouter::new(network_graph, entropy_source) }
+ }
+}
+
+impl<'a> MessageRouter for TestMessageRouter<'a> {
+ fn find_path(
+ &self, sender: PublicKey, peers: Vec<PublicKey>, destination: Destination
+ ) -> Result<OnionMessagePath, ()> {
+ self.inner.find_path(sender, peers, destination)
+ }
+
+ fn create_blinded_paths<T: secp256k1::Signing + secp256k1::Verification>(
+ &self, recipient: PublicKey, peers: Vec<PublicKey>, secp_ctx: &Secp256k1<T>,
+ ) -> Result<Vec<BlindedPath>, ()> {
+ self.inner.create_blinded_paths(recipient, peers, secp_ctx)
+ }
+}
+